From ee85ddd394052efaaa35352e2567b39b6d2a215e Mon Sep 17 00:00:00 2001
From: swathipil <76007337+swathipil@users.noreply.github.com>
Date: Sun, 7 May 2023 15:14:30 -0500
Subject: [PATCH] [SB] merge pyproto main2 (#30269)

* upstream + sb pyamqp
* pyamqp from eventhub
* asyncio markers on async tests
* remove recordings
* updating sender and session from main
* missing req

* [Service Bus] Performance Tests (#28399)

* delete t1 tests
* perf tests
* fix size
* few small fixes for topic
* perf tests redone
* stream tests
* update logic to handle recv & del
* restructure classes
* formatting
* preload only delta messages
* fixes
* rename test files
* swathis comments
* perf bicep
* fix transport & remove shared
* if peeklock, then complete
* make batch receive batchperftests
* restructure classes
* fix
* fixes for max_count
* fix wording
* minor clean up
* move add args to mixin

* [ServiceBus] Iterator Support (#28558)

* iterator
* add todo
* remove logger
* tests
* skip serialization and pyamqp transport errors
* add in keep-alive for releasing messages
* try this for client.py
* trying to fix iterator vs normal receiving timeout for releasing messages
* remove part of test failing for sync release -- doesn't work on pyamqp
* make message yielding a while loop
* copying changes to async
* fixing async release test
* message_received was set incorrectly/ diff than uamqp
* fix async
* fixing closing order logic
* receive_context, fix test
* another receive_context
* asyncio.lock
* receiver_context
* try to ignore sync for now
* async with lock
* receive_context
* remove print statements
* unskip sync
* fix pylint and mypy
* skip tests except for failing one
* mark not mock
* socket read timeout set to 1 but was .2 on sync
* run all tests with new timeout
* pr comments
* pr comments
* typing add back
* remove todo
* time to live - ttl
* async init
* async init
* pr comments
* Revert "pr comments"

This reverts commit f8c96f25f2e50e708a55ed44b36cbf6e10cb5001.
* pr comments - lock rename
* pr comments - remove timeout setting for uamqp
* tests
* set link credit in connection listen on keep alive
* link async
* change pop to get
* try flow before connection
* test
* link credit not being kept bc of flow in client_run
* if link credit is 0, reset it
* add wait_time to tests
* time override test
* Connection to Link Error
* sleep fix
* link_credit
* missing _ in test
* remove boolean flag
* need activity timestamp in yield message
* missing "_"
* stamp -> timestamp
* don't fix EH here
* formatting pylint
* sock timeout async
* whitespace
* remove __aiter__
* need iter
* remove todo
* todo
* whitespace
* iter_context
* remove self
* tests - remove #pytest.skip
* add timeouts to constants
* merge together if statements
* timeout
* missing if
* pylint
* pylint
* whitespace
* todo
* if statement :)
* pylint/remove whitespace

* [ServiceBus] merge EH pyamqp into SB pyamqp (#29223)

* [TEMPORARY] adding eh _pyamqp folder
* [TEMP] add _pyamqp/aio
* undo removing client lock
* lint

* [ServiceBus] update pyproto b1 version/changelog/readme (#29251)

* [ServiceBus] update changelog 7.9.0b1 (#29267)

* Increment version for servicebus releases (#29268)

Increment package version after release of azure-servicebus

* [ServiceBus] Pyamqp Changes from EH in SB (#29499)

* frame fix sync
* frame fix async

* [ServiceBus] uamqp/pyamqp switch (#28512)

- Added _pyamqp_transport.py/_uamqp_transport.py, which contain all corresponding uamqp/pyamqp code.

TODO:
- tests:
  - [ ] manually create ServiceBusMessageBatch and set client to `uamqp_transport=True` when sending.
* [ServiceBus] Files for SB Perf Tests (#29503)

* files for perf test CI
* fix args, vals based on comments
* remove unused params
* add batch size back in to perf test
* add in add_arguments in to send base
* adjust message size
* fix

* [Service Bus] Fix System.Byte[] Not Supported (#29670)

* add in string decode for 161
* app keys are now strings
* app keys are now strings
* remove decode
* revert changes
* decode if app props val is bytes
* remove change from pyamqp layer
* move fix in to outgoing message
* remove extra )
* fix
* mypy and pylint
* fix in back_compat
* rephrase

* [ServiceBus] Fix sb perf test (#29765)

* fixes for perf bicep
* fix env vars and params

* [ServiceBus] prep release 7.10.0b1 (#29815)

* prep for release
* update readmes
* update docs for switch
* remove uamqp dev req temporarily
* fix mypy/pylint
* add back uamqp to dev reqs
* merge main in topyamqp
* update release date

* Increment package version after release of azure-servicebus (#29881)

* [SB] Remove references to internal streaming method (#29750)

* remove streaming
* update stress

* move uamqp transport imports into client constructor (#29921)

* [ServiceBus][Perf] Fix perf tests (#30004)

Some async tests were trying to use a synchronous receiver/sender. This fixes that so that the tests can run.

Signed-off-by: Paul Van Eck

* [SB Pyamqp] stress updates (#29783)

* stress updates
* changes
* add memray to stress
* undo docker file changes
* add memray chaos
* timeoutError raise
* devred
* try log to file
* test indiv
* test indv
* updates
* tests
* logging_enable
* stress
* update
* delete
* remove changes to code
* change level
* update chart.yaml
* update to local running of indv components
* updates
* remove
* update
* update test base
* remove eh changes
* logging
* update jpb
* update docker
* update scenarios
* logging

--------

Co-authored-by: swathipil

* [ServiceBus] Update tracing (#29995)

* [ServiceBus] Update tracing

- "Send" span now contains links to message spans.
- Receive span is now kind CLIENT instead of CONSUMER.
- Added span creation logic for settlement methods.
- Attribute names were updated to align with distributed tracing conventions.
- Some span names were renamed to align with other SDKs.
- Receive spans now have more accurate start times.

Signed-off-by: Paul Van Eck

* Refactor tracing utils

Signed-off-by: Paul Van Eck

* Remove unneeded arg from trace_message

Signed-off-by: Paul Van Eck

* update changelog

Signed-off-by: Paul Van Eck

* Remove use of `messaging.source.name`

This is slated to be removed in favor of `messaging.destination.name` for everything. Here, we maintain use of the legacy attribute name `message_bus.destination`.

Signed-off-by: Paul Van Eck

* remove test-resources.bicep from stress

---------

Signed-off-by: Paul Van Eck
Co-authored-by: swathipil

* [ServiceBus] Fix Memory Leak on Network Drop + Use Asyncio Streams (#29904)

* use non blocking socket + raise on errno 110
* pylint fixes
* comment for errno 110

* [ServiceBus] pyamqp exception parity (#30020)

* unskip tests
* test passing uamqp.TransportType
* make sbreceived messages picklable
* edge case sb message batch creation test
* remove accidental additions
* add sb client tests
* add invalid custom endpoint tests
* update pyamqp invalid custom endpoint error
* add test_errors back to folder
* add more tests
* lint
* fix unskipped async test
* kashif comments
* fix asyncio pickling for <3.11
* lint
* unpickle clients
* remove receiver/uamqp message from received message pickling
* annas comments
* update version + changelog
* update amqp transport kind check in message
* changelog + update to stable
* pull main again
* update readme/typing
* test session set_state None

---------

Signed-off-by: Paul Van Eck
Co-authored-by: l0lawrence
Co-authored-by: Kashif Khan <361477+kashifkhan@users.noreply.github.com>
Co-authored-by: Azure SDK Bot <53356347+azure-sdk@users.noreply.github.com>
Co-authored-by: Paul Van Eck
---
sdk/servicebus/azure-servicebus/CHANGELOG.md | 42 +- sdk/servicebus/azure-servicebus/README.md | 29 + .../azure/servicebus/__init__.py | 5 +- .../azure/servicebus/_base_handler.py | 204 +-- .../servicebus/_common/_configuration.py | 40 +- .../servicebus/_common/auto_lock_renewer.py | 6 +- .../azure/servicebus/_common/constants.py | 63 +- .../azure/servicebus/_common/message.py | 442 ++++--- .../azure/servicebus/_common/mgmt_handlers.py | 68 +- .../servicebus/_common/receiver_mixins.py | 112 +- .../azure/servicebus/_common/tracing.py | 302 +++++ .../azure/servicebus/_common/utils.py | 245 ++-- .../azure/servicebus/_pyamqp/__init__.py | 21 + .../azure/servicebus/_pyamqp/_connection.py | 856 +++++++++++++ .../azure/servicebus/_pyamqp/_decode.py | 349 ++++++ .../azure/servicebus/_pyamqp/_encode.py | 920 ++++++++++++++ .../servicebus/_pyamqp/_message_backcompat.py | 258 ++++ .../azure/servicebus/_pyamqp/_platform.py | 107 ++ .../azure/servicebus/_pyamqp/_transport.py | 805 ++++++++++++ .../azure/servicebus/_pyamqp/aio/__init__.py | 35 + .../_pyamqp/aio/_authentication_async.py | 70 ++ .../servicebus/_pyamqp/aio/_cbs_async.py | 260 ++++ .../servicebus/_pyamqp/aio/_client_async.py | 965 +++++++++++++++ .../_pyamqp/aio/_connection_async.py | 874 +++++++++++++ .../servicebus/_pyamqp/aio/_link_async.py | 260 ++++ .../_pyamqp/aio/_management_link_async.py | 249 ++++ .../aio/_management_operation_async.py | 140 +++ .../servicebus/_pyamqp/aio/_receiver_async.py | 124 ++ .../servicebus/_pyamqp/aio/_sasl_async.py | 149 +++ .../servicebus/_pyamqp/aio/_sender_async.py | 203 +++ .../servicebus/_pyamqp/aio/_session_async.py | 460 +++++++ .../_pyamqp/aio/_transport_async.py | 547 ++++++++ .../servicebus/_pyamqp/authentication.py | 175 +++ .../azure/servicebus/_pyamqp/cbs.py | 299 +++++ .../azure/servicebus/_pyamqp/client.py | 1058 ++++++++++++++++ .../azure/servicebus/_pyamqp/constants.py | 341 +++++ .../azure/servicebus/_pyamqp/endpoints.py | 280 +++++ 
.../azure/servicebus/_pyamqp/error.py | 356 ++++++ .../azure/servicebus/_pyamqp/link.py | 259 ++++ .../servicebus/_pyamqp/management_link.py | 262 ++++ .../_pyamqp/management_operation.py | 140 +++ .../azure/servicebus/_pyamqp/message.py | 272 ++++ .../azure/servicebus/_pyamqp/outcomes.py | 160 +++ .../azure/servicebus/_pyamqp/performatives.py | 634 ++++++++++ .../azure/servicebus/_pyamqp/receiver.py | 121 ++ .../azure/servicebus/_pyamqp/sasl.py | 146 +++ .../azure/servicebus/_pyamqp/sender.py | 200 +++ .../azure/servicebus/_pyamqp/session.py | 507 ++++++++ .../azure/servicebus/_pyamqp/types.py | 90 ++ .../azure/servicebus/_pyamqp/utils.py | 138 +++ .../azure/servicebus/_servicebus_client.py | 64 +- .../azure/servicebus/_servicebus_receiver.py | 265 ++-- .../azure/servicebus/_servicebus_sender.py | 244 ++-- .../azure/servicebus/_servicebus_session.py | 18 +- .../azure/servicebus/_transport/__init__.py | 4 + .../azure/servicebus/_transport/_base.py | 333 +++++ .../_transport/_pyamqp_transport.py | 954 ++++++++++++++ .../servicebus/_transport/_uamqp_transport.py | 1099 +++++++++++++++++ .../azure/servicebus/_version.py | 2 +- .../aio/_async_auto_lock_renewer.py | 4 +- .../azure/servicebus/aio/_async_utils.py | 56 +- .../servicebus/aio/_base_handler_async.py | 118 +- .../aio/_servicebus_client_async.py | 54 +- .../aio/_servicebus_receiver_async.py | 295 ++--- .../aio/_servicebus_sender_async.py | 178 ++- .../aio/_servicebus_session_async.py | 2 +- .../servicebus/aio/_transport/__init__.py | 4 + .../servicebus/aio/_transport/_base_async.py | 295 +++++ .../aio/_transport/_pyamqp_transport_async.py | 384 ++++++ .../aio/_transport/_uamqp_transport_async.py | 332 +++++ .../management/_management_client_async.py | 20 +- .../management/_shared_key_policy_async.py | 2 +- .../azure/servicebus/amqp/_amqp_message.py | 204 +-- .../azure/servicebus/amqp/_amqp_utils.py | 25 + .../azure/servicebus/amqp/_constants.py | 9 - .../azure/servicebus/exceptions.py | 218 ---- 
.../management/_management_client.py | 2 +- .../management/_shared_key_policy.py | 2 +- sdk/servicebus/azure-servicebus/conftest.py | 2 - .../azure-servicebus/dev_requirements.txt | 2 + sdk/servicebus/azure-servicebus/setup.py | 1 - .../azure-servicebus/stress/.helmignore | 6 + .../azure-servicebus/stress/Chart.lock | 4 +- .../azure-servicebus/stress/Chart.yaml | 2 +- .../azure-servicebus/stress/Dockerfile | 4 +- .../stress/scenarios-matrix.yaml | 33 +- .../stress/scripts/dev_requirements.txt | 4 +- .../azure-servicebus/stress/scripts/logger.py | 70 +- .../stress/scripts/process_monitor.py | 2 +- .../stress/scripts/stress_runner.py | 18 +- .../stress/scripts/stress_test_base.py | 194 +-- .../stress/scripts/test_stress_queues.py | 310 +++-- .../scripts/test_stress_queues_async.py | 399 ++++++ .../stress/stress-test-resources.bicep | 57 +- .../stress/templates/network_loss.yaml | 25 - .../stress/templates/testjob.yaml | 66 +- .../tests/async_tests/test_queues_async.py | 855 ++++++++----- .../tests/async_tests/test_sb_client_async.py | 290 +++-- .../tests/async_tests/test_sessions_async.py | 219 +++- .../async_tests/test_subscriptions_async.py | 36 +- .../tests/async_tests/test_topic_async.py | 22 +- .../tests/livetest/test_errors.py | 44 - .../perf_tests/T1_legacy_tests/__init__.py | 0 .../perf_tests/T1_legacy_tests/_test_base.py | 153 --- .../T1_legacy_tests/receive_message_batch.py | 27 - .../T1_legacy_tests/send_message.py | 25 - .../T1_legacy_tests/send_message_batch.py | 26 - .../T1_legacy_tests/t1_test_requirements.txt | 1 - .../tests/perf_tests/_test_base.py | 320 +++-- .../tests/perf_tests/receive_message_batch.py | 31 - .../perf_tests/receive_queue_message_batch.py | 29 + ...eam.py => receive_queue_message_stream.py} | 8 +- .../receive_subscription_message_batch.py | 27 + ...=> receive_subscription_message_stream.py} | 14 +- .../tests/perf_tests/send_message.py | 23 - .../tests/perf_tests/send_message_batch.py | 40 - 
.../tests/perf_tests/send_queue_message.py | 35 + .../perf_tests/send_queue_message_batch.py | 31 + .../tests/perf_tests/send_topic_message.py | 35 + .../perf_tests/send_topic_message_batch.py | 30 + .../tests/servicebus_preparer.py | 32 +- .../tests/test_connection_string_parser.py | 4 +- .../azure-servicebus/tests/test_message.py | 724 ++++++++++- .../azure-servicebus/tests/test_queues.py | 900 +++++++++----- .../azure-servicebus/tests/test_sb_client.py | 287 +++-- .../azure-servicebus/tests/test_sessions.py | 231 ++-- .../tests/test_subscriptions.py | 43 +- .../azure-servicebus/tests/test_topic.py | 30 +- .../tests/unittests/test_errors.py | 81 ++ .../azure-servicebus/tests/utilities.py | 57 +- sdk/servicebus/perf-resources.bicep | 68 + sdk/servicebus/perf-tests.yml | 44 + sdk/servicebus/perf.yml | 36 + sdk/servicebus/tests.yml | 2 +- 134 files changed, 22405 insertions(+), 3489 deletions(-) create mode 100644 sdk/servicebus/azure-servicebus/azure/servicebus/_common/tracing.py create mode 100644 sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/__init__.py create mode 100644 sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/_connection.py create mode 100644 sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/_decode.py create mode 100644 sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/_encode.py create mode 100644 sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/_message_backcompat.py create mode 100644 sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/_platform.py create mode 100644 sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/_transport.py create mode 100644 sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/aio/__init__.py create mode 100644 sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/aio/_authentication_async.py create mode 100644 sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/aio/_cbs_async.py create mode 100644 
sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/aio/_client_async.py create mode 100644 sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/aio/_connection_async.py create mode 100644 sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/aio/_link_async.py create mode 100644 sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/aio/_management_link_async.py create mode 100644 sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/aio/_management_operation_async.py create mode 100644 sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/aio/_receiver_async.py create mode 100644 sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/aio/_sasl_async.py create mode 100644 sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/aio/_sender_async.py create mode 100644 sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/aio/_session_async.py create mode 100644 sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/aio/_transport_async.py create mode 100644 sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/authentication.py create mode 100644 sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/cbs.py create mode 100644 sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/client.py create mode 100644 sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/constants.py create mode 100644 sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/endpoints.py create mode 100644 sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/error.py create mode 100644 sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/link.py create mode 100644 sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/management_link.py create mode 100644 sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/management_operation.py create mode 100644 sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/message.py create mode 100644 sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/outcomes.py create mode 100644 
sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/performatives.py create mode 100644 sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/receiver.py create mode 100644 sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/sasl.py create mode 100644 sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/sender.py create mode 100644 sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/session.py create mode 100644 sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/types.py create mode 100644 sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/utils.py create mode 100644 sdk/servicebus/azure-servicebus/azure/servicebus/_transport/__init__.py create mode 100644 sdk/servicebus/azure-servicebus/azure/servicebus/_transport/_base.py create mode 100644 sdk/servicebus/azure-servicebus/azure/servicebus/_transport/_pyamqp_transport.py create mode 100644 sdk/servicebus/azure-servicebus/azure/servicebus/_transport/_uamqp_transport.py create mode 100644 sdk/servicebus/azure-servicebus/azure/servicebus/aio/_transport/__init__.py create mode 100644 sdk/servicebus/azure-servicebus/azure/servicebus/aio/_transport/_base_async.py create mode 100644 sdk/servicebus/azure-servicebus/azure/servicebus/aio/_transport/_pyamqp_transport_async.py create mode 100644 sdk/servicebus/azure-servicebus/azure/servicebus/aio/_transport/_uamqp_transport_async.py create mode 100644 sdk/servicebus/azure-servicebus/azure/servicebus/amqp/_amqp_utils.py create mode 100644 sdk/servicebus/azure-servicebus/stress/.helmignore create mode 100644 sdk/servicebus/azure-servicebus/stress/scripts/test_stress_queues_async.py delete mode 100644 sdk/servicebus/azure-servicebus/stress/templates/network_loss.yaml delete mode 100644 sdk/servicebus/azure-servicebus/tests/livetest/test_errors.py delete mode 100644 sdk/servicebus/azure-servicebus/tests/perf_tests/T1_legacy_tests/__init__.py delete mode 100644 sdk/servicebus/azure-servicebus/tests/perf_tests/T1_legacy_tests/_test_base.py delete 
mode 100644 sdk/servicebus/azure-servicebus/tests/perf_tests/T1_legacy_tests/receive_message_batch.py delete mode 100644 sdk/servicebus/azure-servicebus/tests/perf_tests/T1_legacy_tests/send_message.py delete mode 100644 sdk/servicebus/azure-servicebus/tests/perf_tests/T1_legacy_tests/send_message_batch.py delete mode 100644 sdk/servicebus/azure-servicebus/tests/perf_tests/T1_legacy_tests/t1_test_requirements.txt delete mode 100644 sdk/servicebus/azure-servicebus/tests/perf_tests/receive_message_batch.py create mode 100644 sdk/servicebus/azure-servicebus/tests/perf_tests/receive_queue_message_batch.py rename sdk/servicebus/azure-servicebus/tests/perf_tests/{receive_message_stream.py => receive_queue_message_stream.py} (85%) create mode 100644 sdk/servicebus/azure-servicebus/tests/perf_tests/receive_subscription_message_batch.py rename sdk/servicebus/azure-servicebus/tests/perf_tests/{T1_legacy_tests/receive_message_stream.py => receive_subscription_message_stream.py} (76%) delete mode 100644 sdk/servicebus/azure-servicebus/tests/perf_tests/send_message.py delete mode 100644 sdk/servicebus/azure-servicebus/tests/perf_tests/send_message_batch.py create mode 100644 sdk/servicebus/azure-servicebus/tests/perf_tests/send_queue_message.py create mode 100644 sdk/servicebus/azure-servicebus/tests/perf_tests/send_queue_message_batch.py create mode 100644 sdk/servicebus/azure-servicebus/tests/perf_tests/send_topic_message.py create mode 100644 sdk/servicebus/azure-servicebus/tests/perf_tests/send_topic_message_batch.py create mode 100644 sdk/servicebus/azure-servicebus/tests/unittests/test_errors.py create mode 100644 sdk/servicebus/perf-resources.bicep create mode 100644 sdk/servicebus/perf-tests.yml create mode 100644 sdk/servicebus/perf.yml diff --git a/sdk/servicebus/azure-servicebus/CHANGELOG.md b/sdk/servicebus/azure-servicebus/CHANGELOG.md index 7a182a2f26001..8ab2ea6377057 100644 --- a/sdk/servicebus/azure-servicebus/CHANGELOG.md +++ 
b/sdk/servicebus/azure-servicebus/CHANGELOG.md @@ -1,15 +1,53 @@ # Release History -## 7.9.1 (Unreleased) +## 7.10.0 (2023-05-09) + +Version 7.10.0 is our first stable release of the Azure Service Bus client library based on a pure Python implemented AMQP stack. ### Features Added -### Breaking Changes +- A new boolean keyword argument `uamqp_transport` has been added to sync and async `ServiceBusClient` constructors which indicates whether to use the `uamqp` library or the default pure Python AMQP library as the underlying transport. ### Bugs Fixed +- Fixed a bug where sync and async `ServiceBusAdministrationClient` expected `credential` with `get_token` method returning `AccessToken.token` of type `bytes` and not `str`, now matching the documentation. +- Fixed a bug where `raw_amqp_message.header` and `message.header` properties on `ServiceBusReceivedMessage` were returned with `durable`, `first_acquirer`, and `priority` properties set by default, rather than the values returned by the service. +- Fixed a bug where `ServiceBusReceivedMessage` was not picklable (Issue #27947). + ### Other Changes +- The `message` attribute on `ServiceBusMessage`/`ServiceBusMessageBatch`/`ServiceBusReceivedMessage`, which previously exposed the `uamqp.Message`/`uamqp.BatchMessage`, has been deprecated. - `LegacyMessage`/`LegacyBatchMessage` objects returned by the `message` attribute on `ServiceBusMessage`/`ServiceBusMessageBatch` have been introduced to help facilitate the transition. +- Removed uAMQP from required dependencies. +- Added `uamqp >= 1.6.3` as an optional dependency for use with the `uamqp_transport` keyword. 
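The pickling fix called out above ("`ServiceBusReceivedMessage` was not picklable") typically comes down to excluding the live receiver handle from the pickled state, which is what the squashed commit list hints at ("remove receiver/uamqp message from received message pickling"). Below is a minimal, hedged sketch of that general pattern; the class and attribute names are illustrative stand-ins, not the SDK's actual internals.

```python
import pickle


class ReceivedMessage:
    """Toy stand-in for a received-message type that holds an unpicklable handle."""

    def __init__(self, body, receiver=None):
        self.body = body
        self._receiver = receiver  # holds sockets/locks in real life; not picklable

    def __getstate__(self):
        # Drop the live receiver so the remaining state can be pickled.
        state = self.__dict__.copy()
        state["_receiver"] = None
        return state


msg = ReceivedMessage(b"payload", receiver=object())
clone = pickle.loads(pickle.dumps(msg))  # round-trips without the receiver
```

The default `__setstate__` restores the returned dict, so the clone keeps its body while the receiver reference is simply absent after unpickling.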
+ - Updated tracing ([#29995](https://github.com/Azure/azure-sdk-for-python/pull/29995)): + - Additional attributes added to existing spans: + - `messaging.system` - messaging system (i.e., `servicebus`) + - `messaging.operation` - type of operation (i.e., `publish`, `receive`, or `settle`) + - `messaging.batch.message_count` - number of messages sent or received (if more than one) + - A span will now be created upon calls to the service that settle messages. + - The span name will contain the settlement operation (e.g., `ServiceBus.complete`) + - The span will contain `az.namespace`, `messaging.destination.name`, `net.peer.name`, `messaging.system`, and `messaging.operation` attributes. + - All `send` spans now contain links to `message` spans. Now, `message` spans will no longer contain a link to the `send` span. + +## 7.10.0b1 (2023-04-13) + +### Features Added + +- A new boolean keyword argument `uamqp_transport` has been added to sync and async `ServiceBusClient` constructors which indicates whether to use the `uamqp` library or the default pure Python AMQP library as the underlying transport. + +### Bugs Fixed + +- Fixed a bug where sync and async `ServiceBusAdministrationClient` expected `credential` with `get_token` method returning `AccessToken.token` of type `bytes` and not `str`, now matching the documentation. +- Fixed a bug where `raw_amqp_message.header` and `message.header` properties on `ServiceBusReceivedMessage` were returned with `durable`, `first_acquirer`, and `priority` properties set by default, rather than the values returned by the service. + +### Other Changes + +- The `message` attribute on `ServiceBusMessage`/`ServiceBusMessageBatch`/`ServiceBusReceivedMessage`, which previously exposed the `uamqp.Message`/`uamqp.BatchMessage`, has been deprecated. - `LegacyMessage`/`LegacyBatchMessage` objects returned by the `message` attribute on `ServiceBusMessage`/`ServiceBusMessageBatch` have been introduced to help facilitate the transition. 
+- Removed uAMQP from required dependencies. +- Added `uamqp >= 1.6.3` as an optional dependency for use with the `uamqp_transport` keyword. + ## 7.9.0 (2023-04-11) ### Breaking Changes diff --git a/sdk/servicebus/azure-servicebus/README.md b/sdk/servicebus/azure-servicebus/README.md index 5fe4de6ae4b2e..e4e0d349dfa94 100644 --- a/sdk/servicebus/azure-servicebus/README.md +++ b/sdk/servicebus/azure-servicebus/README.md @@ -480,12 +480,41 @@ For users seeking to perform management operations against ServiceBus (Creating please see the [azure-mgmt-servicebus documentation][service_bus_mgmt_docs] for API documentation. Terse usage examples can be found [here](https://github.com/Azure/azure-sdk-for-python/tree/main/sdk/servicebus/azure-mgmt-servicebus/tests) as well. +### Pure Python AMQP Transport and Backward Compatibility Support + +The Azure Service Bus client library is now based on a pure Python AMQP implementation. `uAMQP` has been removed as a required dependency. + +To use `uAMQP` as the underlying transport: + +1. Install `uamqp` with pip. + +``` +$ pip install uamqp +``` + +2. Pass `uamqp_transport=True` during client construction. + +```python +from azure.servicebus import ServiceBusClient +connection_str = '<< CONNECTION STRING FOR THE SERVICE BUS NAMESPACE >>' +queue_name = '<< NAME OF THE QUEUE >>' +client = ServiceBusClient.from_connection_string( + connection_str, uamqp_transport=True +) +``` + +Note: The `message` attribute on `ServiceBusMessage`/`ServiceBusMessageBatch`/`ServiceBusReceivedMessage`, which previously exposed the `uamqp.Message`, has been deprecated. + The "Legacy" objects returned by the `message` attribute have been introduced to help facilitate the transition. + ### Building uAMQP wheel from source `azure-servicebus` depends on the [uAMQP](https://pypi.org/project/uamqp/) for the AMQP protocol implementation. uAMQP wheels are provided for most major operating systems and will be installed automatically when installing `azure-servicebus`. 
+If [uAMQP](https://pypi.org/project/uamqp/) is intended to be used as the underlying AMQP protocol implementation for `azure-servicebus`, +uAMQP wheels can be found for most major operating systems. If you're running on a platform for which uAMQP wheels are not provided, please follow the [uAMQP Installation](https://github.com/Azure/azure-uamqp-python#installation) guidance to install from source. ## Contributing diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/__init__.py b/sdk/servicebus/azure-servicebus/azure/servicebus/__init__.py index 511fe13765634..2ab2fa09dfd47 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/__init__.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/__init__.py @@ -3,8 +3,6 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. # ------------------------------------------------------------------------- -from uamqp import constants - from ._version import VERSION __version__ = VERSION @@ -23,6 +21,7 @@ ServiceBusSubQueue, ServiceBusMessageState, ServiceBusSessionFilter, + TransportType, NEXT_AVAILABLE_SESSION, ) from ._common.auto_lock_renewer import AutoLockRenewer @@ -31,8 +30,6 @@ ServiceBusConnectionStringProperties, ) -TransportType = constants.TransportType - __all__ = [ "ServiceBusMessage", "ServiceBusMessageBatch", diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_base_handler.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_base_handler.py index 81605eebb83de..e0fce68eab6ae 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/_base_handler.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_base_handler.py @@ -8,27 +8,22 @@ import threading from datetime import timedelta from typing import cast, Optional, Tuple, TYPE_CHECKING, Dict, Any, Callable, Union +from azure.core.credentials import AccessToken, 
AzureSasCredential, AzureNamedKeyCredential +from azure.core.pipeline.policies import RetryMode try: - from urllib.parse import quote_plus, urlparse + from urllib.parse import urlparse except ImportError: - from urllib import quote_plus # type: ignore from urlparse import urlparse # type: ignore -import uamqp -from uamqp import utils, compat -from uamqp.message import MessageProperties - -from azure.core.credentials import AccessToken, AzureSasCredential, AzureNamedKeyCredential -from azure.core.pipeline.policies import RetryMode +from ._pyamqp.utils import generate_sas_token +from ._transport._pyamqp_transport import PyamqpTransport from ._common._configuration import Configuration from .exceptions import ( - ServiceBusError, ServiceBusConnectionError, OperationTimeoutError, SessionLockLostError, - _create_servicebus_exception, ) from ._common.utils import create_properties, strip_protocol_from_uri, parse_sas_credential from ._common.constants import ( @@ -37,27 +32,33 @@ TOKEN_TYPE_SASTOKEN, MGMT_REQUEST_OP_TYPE_ENTITY_MGMT, ASSOCIATEDLINKPROPERTYNAME, - TRACE_NAMESPACE_PROPERTY, - TRACE_COMPONENT_PROPERTY, - TRACE_COMPONENT, - TRACE_PEER_ADDRESS_PROPERTY, - TRACE_BUS_DESTINATION_PROPERTY, ) if TYPE_CHECKING: + from .exceptions import ServiceBusError + try: + # pylint:disable=unused-import + from uamqp import AMQPClient as uamqp_AMQPClientSync + except ImportError: + pass + + from ._pyamqp.message import Message as pyamqp_Message + from ._pyamqp.client import AMQPClient as pyamqp_AMQPClientSync from azure.core.credentials import TokenCredential _LOGGER = logging.getLogger(__name__) -def _parse_conn_str(conn_str, check_case=False): - # type: (str, Optional[bool]) -> Tuple[str, Optional[str], Optional[str], str, Optional[str], Optional[int]] +def _parse_conn_str( + conn_str: str, + check_case: Optional[bool] = False +) -> Tuple[str, Optional[str], Optional[str], str, Optional[str], Optional[int]]: endpoint = None shared_access_key_name = None shared_access_key = 
None - entity_path = None # type: Optional[str] - shared_access_signature = None # type: Optional[str] - shared_access_signature_expiry = None # type: Optional[int] + entity_path: Optional[str] = None + shared_access_signature: Optional[str] = None + shared_access_signature_expiry: Optional[int] = None # split connection string into properties conn_properties = [s.split("=", 1) for s in conn_str.strip().rstrip(";").split(";")] @@ -136,8 +137,9 @@ def _parse_conn_str(conn_str, check_case=False): ) -def _generate_sas_token(uri, policy, key, expiry=None): - # type: (str, str, str, Optional[timedelta]) -> AccessToken +def _generate_sas_token( + uri: str, policy: str, key: str, expiry: Optional[timedelta] = None +) -> AccessToken: """Create a shared access signature token as a string literal. :returns: SAS token as string literal. :rtype: str @@ -146,11 +148,7 @@ def _generate_sas_token(uri, policy, key, expiry=None): expiry = timedelta(hours=1) # Default to 1 hour. abs_expiry = int(time.time()) + expiry.seconds - encoded_uri = quote_plus(uri).encode("utf-8") # pylint: disable=no-member - encoded_policy = quote_plus(policy).encode("utf-8") # pylint: disable=no-member - encoded_key = key.encode("utf-8") - - token = utils.create_sas_token(encoded_policy, encoded_key, encoded_uri, expiry) + token = generate_sas_token(uri, policy, key, abs_expiry) return AccessToken(token=token, expires_on=abs_expiry) def _get_backoff_time(retry_mode, backoff_factor, backoff_max, retried_times): @@ -166,8 +164,7 @@ class ServiceBusSASTokenCredential(object): :param int expiry: The epoch timestamp """ - def __init__(self, token, expiry): - # type: (str, int) -> None + def __init__(self, token: str, expiry: int) -> None: """ :param str token: The shared access token string :param float expiry: The epoch timestamp @@ -176,8 +173,7 @@ def __init__(self, token, expiry): self.expiry = expiry self.token_type = b"servicebus.windows.net:sastoken" - def get_token(self, *scopes, **kwargs): # 
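The `_parse_conn_str` hunk above centers on one split: `Key=Value` pairs separated by `;`, using `maxsplit=1` so base64 key values that themselves contain `=` survive intact. A self-contained sketch of just that step follows (simplified; the real helper also validates casing, entity paths, and SAS fields, and the sample values below are made up):

```python
from typing import Dict


def parse_conn_str(conn_str: str) -> Dict[str, str]:
    # Mirror the split used in the diff: trim whitespace and a trailing ";",
    # then split each "Key=Value" pair only on the FIRST "=".
    pairs = [s.split("=", 1) for s in conn_str.strip().rstrip(";").split(";")]
    if any(len(p) != 2 for p in pairs):
        raise ValueError("Connection string is either blank or malformed.")
    return dict(pairs)


props = parse_conn_str(
    "Endpoint=sb://example.servicebus.windows.net/;"
    "SharedAccessKeyName=RootManageSharedAccessKey;SharedAccessKey=abc123="
)
```

Note that `SharedAccessKey=abc123=` parses correctly: `maxsplit=1` leaves the trailing `=` inside the key value rather than splitting on it.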
pylint:disable=unused-argument - # type: (str, Any) -> AccessToken + def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken: # pylint:disable=unused-argument """ This method is automatically called when token is about to expire. """ @@ -191,14 +187,12 @@ class ServiceBusSharedKeyCredential(object): :param str key: The shared access key. """ - def __init__(self, policy, key): - # type: (str, str) -> None + def __init__(self, policy: str, key: str) -> None: self.policy = policy self.key = key self.token_type = TOKEN_TYPE_SASTOKEN - def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument - # type: (str, Any) -> AccessToken + def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken: # pylint:disable=unused-argument if not scopes: raise ValueError("No token scope provided.") return _generate_sas_token(scopes[0], self.policy, self.key) @@ -210,13 +204,11 @@ class ServiceBusAzureNamedKeyTokenCredential(object): :type credential: ~azure.core.credentials.AzureNamedKeyCredential """ - def __init__(self, azure_named_key_credential): - # type: (AzureNamedKeyCredential) -> None + def __init__(self, azure_named_key_credential: AzureNamedKeyCredential) -> None: self._credential = azure_named_key_credential self.token_type = b"servicebus.windows.net:sastoken" - def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument - # type: (str, Any) -> AccessToken + def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken: # pylint:disable=unused-argument if not scopes: raise ValueError("No token scope provided.") name, key = self._credential.named_key @@ -229,13 +221,11 @@ class ServiceBusAzureSasTokenCredential(object): :param azure_sas_credential: The credential to be used for authentication. 
:type azure_sas_credential: ~azure.core.credentials.AzureSasCredential """ - def __init__(self, azure_sas_credential): - # type: (AzureSasCredential) -> None + def __init__(self, azure_sas_credential: AzureSasCredential) -> None: self._credential = azure_sas_credential self.token_type = b"servicebus.windows.net:sastoken" - def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument - # type: (str, Any) -> AccessToken + def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken: # pylint:disable=unused-argument """ This method is automatically called when token is about to expire. """ @@ -244,8 +234,15 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument class BaseHandler: # pylint:disable=too-many-instance-attributes - def __init__(self, fully_qualified_namespace, entity_name, credential, **kwargs): - # type: (str, str, Union[TokenCredential, AzureSasCredential, AzureNamedKeyCredential], Any) -> None + def __init__( + self, + fully_qualified_namespace: str, + entity_name: str, + credential: Union["TokenCredential", AzureSasCredential, AzureNamedKeyCredential], + **kwargs: Any + ) -> None: + self._amqp_transport = kwargs.pop("amqp_transport", PyamqpTransport) + # If the user provided http:// or sb://, let's be polite and strip that. 
self.fully_qualified_namespace = strip_protocol_from_uri( fully_qualified_namespace.strip() @@ -256,7 +253,7 @@ def __init__(self, fully_qualified_namespace, entity_name, credential, **kwargs) self._entity_path = self._entity_name + ( ("/Subscriptions/" + subscription_name) if subscription_name else "" ) - self._mgmt_target = "{}{}".format(self._entity_path, MANAGEMENT_PATH_SUFFIX) + self._mgmt_target = f"{self._entity_path}{MANAGEMENT_PATH_SUFFIX}" if isinstance(credential, AzureSasCredential): self._credential = ServiceBusAzureSasTokenCredential(credential) elif isinstance(credential, AzureNamedKeyCredential): @@ -264,16 +261,24 @@ def __init__(self, fully_qualified_namespace, entity_name, credential, **kwargs) else: self._credential = credential # type: ignore self._container_id = CONTAINER_PREFIX + str(uuid.uuid4())[:8] - self._config = Configuration(**kwargs) + self._config = Configuration( + hostname=self.fully_qualified_namespace, + amqp_transport=self._amqp_transport, + **kwargs + ) self._running = False - self._handler = None # type: uamqp.AMQPClient + self._handler: Optional[Union["uamqp_AMQPClientSync", "pyamqp_AMQPClientSync"]] = None self._auth_uri = None - self._properties = create_properties(self._config.user_agent) + self._properties = create_properties( + self._config.user_agent, + amqp_transport=self._amqp_transport, + ) self._shutdown = threading.Event() @classmethod - def _convert_connection_string_to_kwargs(cls, conn_str, **kwargs): - # type: (str, Any) -> Dict[str, Any] + def _convert_connection_string_to_kwargs( + cls, conn_str: str, **kwargs: Any + ) -> Dict[str, Any]: host, policy, key, entity_in_conn_str, token, token_expiry = _parse_conn_str( conn_str ) @@ -297,10 +302,8 @@ def _convert_connection_string_to_kwargs(cls, conn_str, **kwargs): and (entity_in_conn_str != entity_in_kwargs) ): raise ValueError( - "The queue or topic name provided: {} which does not match the EntityPath in" - " the connection string passed to the ServiceBusClient 
constructor: {}.".format( - entity_in_conn_str, entity_in_kwargs - ) + f"The queue or topic name provided, {entity_in_conn_str}, does not match the EntityPath in" + f" the connection string passed to the ServiceBusClient constructor: {entity_in_kwargs}." ) kwargs["fully_qualified_namespace"] = host @@ -330,16 +333,18 @@ def __enter__(self): def __exit__(self, *args): self.close() - def _handle_exception(self, exception): - # type: (BaseException) -> ServiceBusError + def _handle_exception(self, exception: BaseException) -> "ServiceBusError": # pylint: disable=protected-access, line-too-long - error = _create_servicebus_exception(_LOGGER, exception) + error = self._amqp_transport.create_servicebus_exception( + _LOGGER, exception, custom_endpoint_address=self._config.custom_endpoint_address + ) try: - # If SessionLockLostError or ServiceBusConnectionError happen when a session receiver is running, - # the receiver should no longer be used and should create a new session receiver - # instance to receive from session. There are pitfalls WRT both next session IDs, - # and the diversity of session failure modes, that motivates us to disallow this. + # If SessionLockLostError or ServiceBusConnectionError happens while a + # session receiver is running, the receiver should no longer be used; a new + # session receiver instance should be created to receive from the session. + # There are pitfalls WRT both next session IDs, and the diversity of session + # failure modes, that motivate us to disallow this.
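Reviewer note on the hunk above: the EntityPath mismatch check relies on `_parse_conn_str` splitting the connection string on `;` and then on the first `=` only, so key values containing `=` padding survive. A minimal standalone sketch of that split logic, with an illustrative helper name (`parse_conn_str` is not the SDK's actual API):

```python
from typing import Dict

def parse_conn_str(conn_str: str) -> Dict[str, str]:
    # Split "Key=Value;Key=Value;" pairs; a value may itself contain "=",
    # e.g. base64 key padding, so split each pair on the first "=" only
    # (this mirrors s.split("=", 1) in the real parser).
    props = {}
    for segment in conn_str.strip().rstrip(";").split(";"):
        key, _, value = segment.partition("=")
        props[key] = value
    return props

props = parse_conn_str(
    "Endpoint=sb://myns.servicebus.windows.net/;"
    "SharedAccessKeyName=RootManageSharedAccessKey;"
    "SharedAccessKey=abc123=;EntityPath=myqueue"
)
# props["EntityPath"] is what gets compared against the entity name kwarg
```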
if self._session and self._running and isinstance(error, (SessionLockLostError, ServiceBusConnectionError)): # type: ignore self._session._lock_lost = True # type: ignore self._close_handler() @@ -380,8 +385,12 @@ def _check_live(self): except AttributeError: pass - def _do_retryable_operation(self, operation, timeout=None, **kwargs): - # type: (Callable, Optional[float], Any) -> Any + def _do_retryable_operation( + self, + operation: Callable, + timeout: Optional[float] = None, + **kwargs: Any + ) -> Any: # pylint: disable=protected-access require_last_exception = kwargs.pop("require_last_exception", False) operation_requires_timeout = kwargs.pop("operation_requires_timeout", False) @@ -402,6 +411,9 @@ def _do_retryable_operation(self, operation, timeout=None, **kwargs): return operation(**kwargs) except StopIteration: raise + except ImportError: + # If dependency is not installed, do not retry. + raise except Exception as exception: # pylint: disable=broad-except last_exception = self._handle_exception(exception) if require_last_exception: @@ -453,14 +465,13 @@ def _backoff( def _mgmt_request_response( self, - mgmt_operation, - message, - callback, - keep_alive_associated_link=True, - timeout=None, - **kwargs - ): - # type: (bytes, Any, Callable, bool, Optional[float], Any) -> uamqp.Message + mgmt_operation: bytes, + message: Any, + callback: Callable, + keep_alive_associated_link: bool = True, + timeout: Optional[float] = None, + **kwargs: Any + ) -> "pyamqp_Message": """ Execute an amqp management operation. @@ -470,11 +481,11 @@ def _mgmt_request_response( :param message: The message to send in the management request. :paramtype message: Any :param callback: The callback which is used to parse the returning message. 
- :paramtype callback: Callable[int, ~uamqp.message.Message, str] + :paramtype callback: Callable[int, Union[~uamqp.message.Message, Message], str] :param keep_alive_associated_link: A boolean flag for keeping associated amqp sender/receiver link alive when executing operation on mgmt links. :param timeout: timeout in seconds executing the mgmt operation. - :rtype: None + :rtype: Tuple """ self._open() application_properties = {} @@ -483,36 +494,42 @@ def _mgmt_request_response( if keep_alive_associated_link: try: application_properties = { - ASSOCIATEDLINKPROPERTYNAME: self._handler.message_handler.name + ASSOCIATEDLINKPROPERTYNAME: self._amqp_transport.get_handler_link_name(self._handler) } except AttributeError: pass - mgmt_msg = uamqp.Message( - body=message, - properties=MessageProperties( - reply_to=self._mgmt_target, encoding=self._config.encoding, **kwargs - ), + mgmt_msg = self._amqp_transport.create_mgmt_msg( + message=message, application_properties=application_properties, + config=self._config, + reply_to=self._mgmt_target, + **kwargs ) + try: - return self._handler.mgmt_request( + return self._amqp_transport.mgmt_client_request( + self._handler, mgmt_msg, - mgmt_operation, - op_type=MGMT_REQUEST_OP_TYPE_ENTITY_MGMT, + operation=mgmt_operation, + operation_type=MGMT_REQUEST_OP_TYPE_ENTITY_MGMT, node=self._mgmt_target.encode(self._config.encoding), - timeout=timeout * 1000 if timeout else None, - callback=callback, + timeout=timeout, + callback=callback ) except Exception as exp: # pylint: disable=broad-except - if isinstance(exp, compat.TimeoutException): + if isinstance(exp, self._amqp_transport.TIMEOUT_ERROR): raise OperationTimeoutError(error=exp) raise def _mgmt_request_response_with_retry( - self, mgmt_operation, message, callback, timeout=None, **kwargs - ): - # type: (bytes, Dict[str, Any], Callable, Optional[float], Any) -> Any + self, + mgmt_operation: bytes, + message: Dict[str, Any], + callback: Callable, + timeout: Optional[float] = None, + 
**kwargs: Any + ) -> Any: return self._do_retryable_operation( self._mgmt_request_response, mgmt_operation=mgmt_operation, @@ -523,12 +540,6 @@ def _mgmt_request_response_with_retry( **kwargs ) - def _add_span_request_attributes(self, span): - span.add_attribute(TRACE_COMPONENT_PROPERTY, TRACE_COMPONENT) - span.add_attribute(TRACE_NAMESPACE_PROPERTY, TRACE_NAMESPACE_PROPERTY) - span.add_attribute(TRACE_BUS_DESTINATION_PROPERTY, self._entity_path) - span.add_attribute(TRACE_PEER_ADDRESS_PROPERTY, self.fully_qualified_namespace) - def _open(self): # pylint: disable=no-self-use raise ValueError("Subclass should override the method.") @@ -541,8 +552,7 @@ def _close_handler(self): self._handler = None self._running = False - def close(self): - # type: () -> None + def close(self) -> None: """Close down the handler links (and connection if the handler uses a separate connection). If the handler has already closed, this operation will do nothing. diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_common/_configuration.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_common/_configuration.py index a445e497b6120..a788414a9df6d 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/_common/_configuration.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_common/_configuration.py @@ -1,30 +1,34 @@ -# -------------------------------------------------------------------------------------------- +# # -------------------------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for license information. 
# -------------------------------------------------------------------------------------------- -from typing import Optional, Dict, Any +from typing import Optional, Dict, Any, Union, TYPE_CHECKING from urllib.parse import urlparse -from uamqp.constants import TransportType, DEFAULT_AMQP_WSS_PORT, DEFAULT_AMQPS_PORT from azure.core.pipeline.policies import RetryMode - +from .constants import DEFAULT_AMQPS_PORT, DEFAULT_AMQP_WSS_PORT, TransportType +if TYPE_CHECKING: + from .._transport._base import AmqpTransport + from ..aio._transport._base_async import AmqpTransportAsync class Configuration(object): # pylint:disable=too-many-instance-attributes def __init__(self, **kwargs): - self.user_agent = kwargs.get("user_agent") # type: Optional[str] - self.retry_total = kwargs.get("retry_total", 3) # type: int + self.user_agent: Optional[str] = kwargs.get("user_agent") + self.retry_total: int = kwargs.get("retry_total", 3) self.retry_mode = RetryMode(kwargs.get("retry_mode", 'exponential')) - self.retry_backoff_factor = kwargs.get( + self.retry_backoff_factor: float = kwargs.get( "retry_backoff_factor", 0.8 - ) # type: float - self.retry_backoff_max = kwargs.get("retry_backoff_max", 120) # type: int - self.logging_enable = kwargs.get("logging_enable", False) # type: bool - self.http_proxy = kwargs.get("http_proxy") # type: Optional[Dict[str, Any]] + ) + self.retry_backoff_max: int = kwargs.get("retry_backoff_max", 120) + self.logging_enable: bool = kwargs.get("logging_enable", False) + self.http_proxy: Optional[Dict[str, Any]] = kwargs.get("http_proxy") - self.custom_endpoint_address = kwargs.get("custom_endpoint_address") # type: Optional[str] - self.connection_verify = kwargs.get("connection_verify") # type: Optional[str] + self.custom_endpoint_address: Optional[str] = kwargs.get("custom_endpoint_address") + self.connection_verify: Optional[str] = kwargs.get("connection_verify") self.connection_port = DEFAULT_AMQPS_PORT self.custom_endpoint_hostname = None + self.hostname 
= kwargs.pop("hostname") + amqp_transport: Union["AmqpTransport", "AmqpTransportAsync"] = kwargs.pop("amqp_transport") self.transport_type = ( TransportType.AmqpOverWebsocket @@ -32,15 +36,17 @@ def __init__(self, **kwargs): else kwargs.get("transport_type", TransportType.Amqp) ) # The following configs are not public, for internal usage only - self.auth_timeout = kwargs.get("auth_timeout", 60) # type: int + self.auth_timeout: float = kwargs.get("auth_timeout", 60) self.encoding = kwargs.get("encoding", "UTF-8") self.auto_reconnect = kwargs.get("auto_reconnect", True) self.keep_alive = kwargs.get("keep_alive", 30) - self.timeout = kwargs.get("timeout", 60) # type: float + self.timeout: float = kwargs.get("timeout", 60) - if self.http_proxy or self.transport_type == TransportType.AmqpOverWebsocket: + if self.http_proxy or self.transport_type.value == TransportType.AmqpOverWebsocket.value: self.transport_type = TransportType.AmqpOverWebsocket self.connection_port = DEFAULT_AMQP_WSS_PORT + if amqp_transport.KIND == "pyamqp": + self.hostname += "/$servicebus/websocket" # custom end point if self.custom_endpoint_address: @@ -51,5 +57,7 @@ def __init__(self, **kwargs): endpoint = urlparse(self.custom_endpoint_address) self.transport_type = TransportType.AmqpOverWebsocket self.custom_endpoint_hostname = endpoint.hostname + if amqp_transport.KIND == "pyamqp": + self.custom_endpoint_address += "/$servicebus/websocket" # in case proxy and custom endpoint are both provided, we default port to 443 if it's not provided self.connection_port = endpoint.port or DEFAULT_AMQP_WSS_PORT diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_common/auto_lock_renewer.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_common/auto_lock_renewer.py index 186e1842e3a4b..34c3f01b6374d 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/_common/auto_lock_renewer.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_common/auto_lock_renewer.py @@ -9,6 +9,7 @@ import 
threading import time from concurrent.futures import ThreadPoolExecutor, TimeoutError as FuturesTimeoutError +import queue from typing import TYPE_CHECKING from .._servicebus_receiver import ServiceBusReceiver @@ -23,11 +24,6 @@ Renewable = Union[ServiceBusSession, ServiceBusReceivedMessage] LockRenewFailureCallback = Callable[[Renewable, Optional[Exception]], None] -try: - import queue -except ImportError: - import Queue as queue # type: ignore - _log = logging.getLogger(__name__) SHORT_RENEW_OFFSET = 0.5 # Seconds that if a renew period is longer than lock duration + offset, it's "too long" diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_common/constants.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_common/constants.py index 17abf70b846ce..204895a17475e 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/_common/constants.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_common/constants.py @@ -4,8 +4,6 @@ # license information. # ------------------------------------------------------------------------- from enum import Enum - -from uamqp import constants, types from azure.core import CaseInsensitiveEnumMeta VENDOR = b"com.microsoft" @@ -61,6 +59,8 @@ JWT_TOKEN_SCOPE = "https://servicebus.azure.net//.default" USER_AGENT_PREFIX = "azsdk-python-servicebus" CONSUMER_IDENTIFIER = VENDOR + b":receiver-name" +UAMQP_LIBRARY = "uamqp" +PYAMQP_LIBRARY = "pyamqp" MANAGEMENT_PATH_SUFFIX = "/$management" @@ -137,31 +137,7 @@ "ServiceBusDlqSupplementaryAuthorization" ) -# Distributed Tracing Constants - -TRACE_COMPONENT_PROPERTY = "component" -TRACE_COMPONENT = "servicebus" - -TRACE_NAMESPACE_PROPERTY = "az.namespace" -TRACE_NAMESPACE = "ServiceBus" - -SPAN_NAME_RECEIVE = TRACE_NAMESPACE + ".receive" -SPAN_NAME_RECEIVE_DEFERRED = TRACE_NAMESPACE + ".receive_deferred" -SPAN_NAME_PEEK = TRACE_NAMESPACE + ".peek" -SPAN_NAME_SEND = TRACE_NAMESPACE + ".send" -SPAN_NAME_SCHEDULE = TRACE_NAMESPACE + ".schedule" -SPAN_NAME_MESSAGE = 
TRACE_NAMESPACE + ".message" - -TRACE_BUS_DESTINATION_PROPERTY = "message_bus.destination" -TRACE_PEER_ADDRESS_PROPERTY = "peer.address" - -SPAN_ENQUEUED_TIME_PROPERTY = "enqueuedTime" - -TRACE_ENQUEUED_TIME_PROPERTY = b"x-opt-enqueued-time" -TRACE_PARENT_PROPERTY = b"Diagnostic-Id" -TRACE_PROPERTY_ENCODING = "ascii" - - +MAX_MESSAGE_LENGTH_BYTES = 1024 * 1024 # Backcompat with uAMQP MESSAGE_PROPERTY_MAX_LENGTH = 128 # .NET TimeSpan.MaxValue: 10675199.02:48:05.4775807 MAX_DURATION_VALUE = 922337203685477 @@ -178,12 +154,6 @@ class ServiceBusMessageState(int, Enum): DEFERRED = 1 SCHEDULED = 2 -# To enable extensible string enums for the public facing parameter, and translate to the "real" uamqp constants. -ServiceBusToAMQPReceiveModeMap = { - ServiceBusReceiveMode.PEEK_LOCK: constants.ReceiverSettleMode.PeekLock, - ServiceBusReceiveMode.RECEIVE_AND_DELETE: constants.ReceiverSettleMode.ReceiveAndDelete, -} - class ServiceBusSessionFilter(Enum): NEXT_AVAILABLE = 0 @@ -194,17 +164,18 @@ class ServiceBusSubQueue(str, Enum, metaclass=CaseInsensitiveEnumMeta): TRANSFER_DEAD_LETTER = "transferdeadletter" -ANNOTATION_SYMBOL_PARTITION_KEY = types.AMQPSymbol(_X_OPT_PARTITION_KEY) -ANNOTATION_SYMBOL_VIA_PARTITION_KEY = types.AMQPSymbol(_X_OPT_VIA_PARTITION_KEY) -ANNOTATION_SYMBOL_SCHEDULED_ENQUEUE_TIME = types.AMQPSymbol( - _X_OPT_SCHEDULED_ENQUEUE_TIME -) - -ANNOTATION_SYMBOL_KEY_MAP = { - _X_OPT_PARTITION_KEY: ANNOTATION_SYMBOL_PARTITION_KEY, - _X_OPT_VIA_PARTITION_KEY: ANNOTATION_SYMBOL_VIA_PARTITION_KEY, - _X_OPT_SCHEDULED_ENQUEUE_TIME: ANNOTATION_SYMBOL_SCHEDULED_ENQUEUE_TIME, -} - - NEXT_AVAILABLE_SESSION = ServiceBusSessionFilter.NEXT_AVAILABLE + +## all below - previously uamqp +class TransportType(Enum): + """Transport type + The underlying transport protocol type: + Amqp: AMQP over the default TCP transport protocol, it uses port 5671. + AmqpOverWebsocket: Amqp over the Web Sockets transport protocol, it uses + port 443. 
+ """ + Amqp = 1 + AmqpOverWebsocket = 2 + +DEFAULT_AMQPS_PORT = 5671 +DEFAULT_AMQP_WSS_PORT = 443 diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_common/message.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_common/message.py index 9b547039d84b8..6b902629c79d5 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/_common/message.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_common/message.py @@ -5,14 +5,16 @@ # ------------------------------------------------------------------------- # pylint: disable=too-many-lines +from __future__ import annotations import time +import warnings import datetime import uuid -import logging -from typing import Optional, Dict, List, Union, Iterable, TYPE_CHECKING, Any, Mapping, cast +from typing import Optional, Dict, List, Union, Iterable, Any, Mapping, cast, TYPE_CHECKING -import uamqp.errors -import uamqp.message +from .._pyamqp._message_backcompat import LegacyMessage, LegacyBatchMessage +from .._pyamqp.message import Message as pyamqp_Message +from .._transport._pyamqp_transport import PyamqpTransport from .constants import ( _BATCH_MESSAGE_OVERHEAD_COST, @@ -28,12 +30,10 @@ _X_OPT_DEAD_LETTER_SOURCE, PROPERTIES_DEAD_LETTER_REASON, PROPERTIES_DEAD_LETTER_ERROR_DESCRIPTION, - ANNOTATION_SYMBOL_PARTITION_KEY, - ANNOTATION_SYMBOL_SCHEDULED_ENQUEUE_TIME, - ANNOTATION_SYMBOL_KEY_MAP, MESSAGE_PROPERTY_MAX_LENGTH, MAX_ABSOLUTE_EXPIRY_TIME, MAX_DURATION_VALUE, + MAX_MESSAGE_LENGTH_BYTES, MESSAGE_STATE_NAME ) from ..amqp import ( @@ -46,26 +46,32 @@ from .utils import ( utc_from_timestamp, utc_now, - trace_message, - transform_messages_if_needed, + transform_outbound_messages, ) +from .tracing import trace_message if TYPE_CHECKING: + try: + # pylint:disable=unused-import + from uamqp import ( + Message, + BatchMessage + ) + except ImportError: + pass + from .._pyamqp.performatives import TransferFrame from ..aio._servicebus_receiver_async import ( ServiceBusReceiver as AsyncServiceBusReceiver, 
) from .._servicebus_receiver import ServiceBusReceiver - from azure.core.tracing import AbstractSpan - PrimitiveTypes = Union[ - int, - float, - bytes, - bool, - str, - uuid.UUID - ] - -_LOGGER = logging.getLogger(__name__) +PrimitiveTypes = Union[ + int, + float, + bytes, + bool, + str, + uuid.UUID +] class ServiceBusMessage( @@ -106,7 +112,7 @@ def __init__( self, body: Optional[Union[str, bytes]], *, - application_properties: Optional[Dict[str, "PrimitiveTypes"]] = None, + application_properties: Optional[Dict[Union[str, bytes], "PrimitiveTypes"]] = None, session_id: Optional[str] = None, message_id: Optional[str] = None, scheduled_enqueue_time_utc: Optional[datetime.datetime] = None, @@ -123,17 +129,18 @@ def __init__( # Although we might normally thread through **kwargs this causes # problems as MessageProperties won't absorb spurious args. self._encoding = kwargs.pop("encoding", "UTF-8") - - if "raw_amqp_message" in kwargs and "message" in kwargs: - # Internal usage only for transforming AmqpAnnotatedMessage to outgoing ServiceBusMessage - self.message = kwargs["message"] - self._raw_amqp_message = kwargs["raw_amqp_message"] - elif "message" in kwargs: - # Note: This cannot be renamed until UAMQP no longer relies on this specific name. 
- self.message = kwargs["message"] - self._raw_amqp_message = AmqpAnnotatedMessage(message=self.message) + self._uamqp_message: Optional[Union[LegacyMessage, "Message"]] = None + self._message: Union["Message", "pyamqp_Message"] = None # type: ignore + + # Internal usage only for transforming AmqpAnnotatedMessage to outgoing ServiceBusMessage + if "message" in kwargs: + self._message = kwargs["message"] + if "raw_amqp_message" in kwargs: + self._raw_amqp_message = kwargs["raw_amqp_message"] + else: + self._raw_amqp_message = AmqpAnnotatedMessage(message=kwargs["message"]) else: - self._build_message(body) + self._build_annotated_message(body) self.application_properties = application_properties self.session_id = session_id self.message_id = message_id @@ -147,12 +154,10 @@ def __init__( self.time_to_live = time_to_live self.partition_key = partition_key - def __str__(self): - # type: () -> str + def __str__(self) -> str: return str(self.raw_amqp_message) - def __repr__(self): - # type: () -> str + def __repr__(self) -> str: # pylint: disable=bare-except message_repr = "body={}".format( str(self) @@ -207,7 +212,7 @@ def __repr__(self): message_repr += ", scheduled_enqueue_time_utc=" return "ServiceBusMessage({})".format(message_repr)[:1024] - def _build_message(self, body): + def _build_annotated_message(self, body): if not ( isinstance(body, (str, bytes)) or (body is None) ): @@ -225,36 +230,50 @@ def _build_message(self, body): def _set_message_annotations(self, key, value): if not self._raw_amqp_message.annotations: self._raw_amqp_message.annotations = {} - - if isinstance(self, ServiceBusReceivedMessage): - try: - del self._raw_amqp_message.annotations[key] - except KeyError: - pass - if value is None: try: - del self._raw_amqp_message.annotations[ANNOTATION_SYMBOL_KEY_MAP[key]] + del self._raw_amqp_message.annotations[key] except KeyError: pass else: - self._raw_amqp_message.annotations[ANNOTATION_SYMBOL_KEY_MAP[key]] = value + 
self._raw_amqp_message.annotations[key] = value - def _to_outgoing_message(self): - # type: () -> ServiceBusMessage - # pylint: disable=protected-access - self.message = self.raw_amqp_message._to_outgoing_amqp_message() - return self + @property + def message(self) -> Union["Message", LegacyMessage]: + """DEPRECATED: Get the underlying uamqp.Message or LegacyMessage. + This is deprecated and will be removed in a later release. + + :rtype: uamqp.Message or LegacyMessage + """ + warnings.warn( + "The `message` property is deprecated and will be removed in future versions.", + DeprecationWarning, + ) + if not self._uamqp_message: + self._uamqp_message = LegacyMessage( + self._raw_amqp_message, + to_outgoing_amqp_message=PyamqpTransport.to_outgoing_amqp_message + ) + return self._uamqp_message + + @message.setter + def message(self, value: "Message") -> None: + """DEPRECATED: Set the underlying Message. + This is deprecated and will be removed in a later release. + """ + warnings.warn( + "The `message` property is deprecated and will be removed in future versions.", + DeprecationWarning, + ) + self._uamqp_message = value @property - def raw_amqp_message(self): - # type: () -> AmqpAnnotatedMessage + def raw_amqp_message(self) -> AmqpAnnotatedMessage: """Advanced usage only. The internal AMQP message payload that is sent or received.""" return self._raw_amqp_message @property - def session_id(self): - # type: () -> Optional[str] + def session_id(self) -> Optional[str]: """The session identifier of the message for a sessionful entity. For sessionful entities, this application-defined value specifies the session affiliation of the message. 
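The deprecated `message` property in the hunk above follows a common shim pattern: warn on access, then lazily build the backward-compatible wrapper only if the old API is actually touched. A self-contained sketch of that pattern (all class names here are illustrative, not SDK types):

```python
import warnings

class LegacyShim:
    """Stand-in for the old API surface (e.g. a LegacyMessage)."""
    def __init__(self, data):
        self.data = data

class Modern:
    def __init__(self, data):
        self._data = data
        self._legacy = None  # built lazily, only on first deprecated access

    @property
    def legacy(self):
        warnings.warn(
            "The `legacy` property is deprecated and will be removed.",
            DeprecationWarning,
        )
        if self._legacy is None:
            self._legacy = LegacyShim(self._data)
        return self._legacy

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    shim = Modern("payload").legacy  # triggers one DeprecationWarning
```

Lazy construction keeps the new code path free of the legacy dependency unless a caller still reaches for the old object.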
@@ -273,8 +292,7 @@ def session_id(self): return self._raw_amqp_message.properties.group_id @session_id.setter - def session_id(self, value): - # type: (str) -> None + def session_id(self, value: str) -> None: if value and len(value) > MESSAGE_PROPERTY_MAX_LENGTH: raise ValueError( "session_id cannot be longer than {} characters.".format( @@ -288,8 +306,7 @@ def session_id(self, value): self._raw_amqp_message.properties.group_id = value @property - def application_properties(self): - # type: () -> Optional[Dict] + def application_properties(self) -> Optional[Dict[Union[str, bytes], PrimitiveTypes]]: """The user defined properties on the message. :rtype: dict @@ -297,13 +314,11 @@ def application_properties(self): return self._raw_amqp_message.application_properties @application_properties.setter - def application_properties(self, value): - # type: (Dict) -> None + def application_properties(self, value: Dict[Union[str, bytes], Any]) -> None: self._raw_amqp_message.application_properties = value @property - def partition_key(self): - # type: () -> Optional[str] + def partition_key(self) -> Optional[str]: """The partition key for sending a message to a partitioned entity. 
Setting this value enables assigning related messages to the same internal partition, so that submission @@ -315,24 +330,16 @@ :rtype: str """ - p_key = None + opt_p_key = None try: - # opt_p_key is used on the incoming message opt_p_key = self._raw_amqp_message.annotations.get(_X_OPT_PARTITION_KEY) # type: ignore if opt_p_key is not None: - p_key = opt_p_key - # symbol_p_key is used on the outgoing message - symbol_p_key = self._raw_amqp_message.annotations.get(ANNOTATION_SYMBOL_PARTITION_KEY) # type: ignore - if symbol_p_key is not None: - p_key = symbol_p_key - - return p_key.decode("UTF-8") # type: ignore + return opt_p_key.decode("UTF-8") except (AttributeError, UnicodeDecodeError): - return p_key + return opt_p_key + return None @partition_key.setter - def partition_key(self, value): - # type: (str) -> None + def partition_key(self, value: str) -> None: if value and len(value) > MESSAGE_PROPERTY_MAX_LENGTH: raise ValueError( "partition_key cannot be longer than {} characters.".format( @@ -349,8 +356,7 @@ def partition_key(self, value): self._set_message_annotations(_X_OPT_PARTITION_KEY, value) @property - def time_to_live(self): - # type: () -> Optional[datetime.timedelta] + def time_to_live(self) -> Optional[datetime.timedelta]: """The life duration of a message.
This value is the relative duration after which the message expires, starting from the instant the message @@ -368,8 +374,7 @@ def time_to_live(self): return None @time_to_live.setter - def time_to_live(self, value): - # type: (datetime.timedelta) -> None + def time_to_live(self, value: Union[datetime.timedelta, int]) -> None: if not self._raw_amqp_message.header: self._raw_amqp_message.header = AmqpMessageHeader() if value is None: @@ -392,8 +397,7 @@ def time_to_live(self, value): ) @property - def scheduled_enqueue_time_utc(self): - # type: () -> Optional[datetime.datetime] + def scheduled_enqueue_time_utc(self) -> Optional[datetime.datetime]: """The utc scheduled enqueue time to the message. This property can be used for scheduling when sending a message through `ServiceBusSender.send` method. @@ -404,9 +408,7 @@ def scheduled_enqueue_time_utc(self): :rtype: ~datetime.datetime """ if self._raw_amqp_message.annotations: - timestamp = self._raw_amqp_message.annotations.get( - _X_OPT_SCHEDULED_ENQUEUE_TIME - ) or self._raw_amqp_message.annotations.get(ANNOTATION_SYMBOL_SCHEDULED_ENQUEUE_TIME) + timestamp = self._raw_amqp_message.annotations.get(_X_OPT_SCHEDULED_ENQUEUE_TIME) if timestamp: try: in_seconds = timestamp / 1000.0 @@ -416,8 +418,7 @@ def scheduled_enqueue_time_utc(self): return None @scheduled_enqueue_time_utc.setter - def scheduled_enqueue_time_utc(self, value): - # type: (datetime.datetime) -> None + def scheduled_enqueue_time_utc(self, value: datetime.datetime) -> None: if not self._raw_amqp_message.properties: self._raw_amqp_message.properties = AmqpMessageProperties() if not self._raw_amqp_message.properties.message_id: @@ -425,8 +426,7 @@ def scheduled_enqueue_time_utc(self, value): self._set_message_annotations(_X_OPT_SCHEDULED_ENQUEUE_TIME, value) @property - def body(self): - # type: () -> Any + def body(self) -> Any: """The body of the Message. 
The format may vary depending on the body type: For :class:`azure.servicebus.amqp.AmqpMessageBodyType.DATA`, the body could be bytes or Iterable[bytes]. @@ -441,8 +441,7 @@ def body(self): return self._raw_amqp_message.body @property - def body_type(self): - # type: () -> AmqpMessageBodyType + def body_type(self) -> AmqpMessageBodyType: """The body type of the underlying AMQP message. :rtype: ~azure.servicebus.amqp.AmqpMessageBodyType @@ -450,8 +449,7 @@ def body_type(self): return self._raw_amqp_message.body_type @property - def content_type(self): - # type: () -> Optional[str] + def content_type(self) -> Optional[str]: """The content type descriptor. Optionally describes the payload of the message, with a descriptor following the format of RFC2045, Section 5, @@ -467,15 +465,13 @@ def content_type(self): return self._raw_amqp_message.properties.content_type @content_type.setter - def content_type(self, value): - # type: (str) -> None + def content_type(self, value: str) -> None: if not self._raw_amqp_message.properties: self._raw_amqp_message.properties = AmqpMessageProperties() self._raw_amqp_message.properties.content_type = value @property - def correlation_id(self): - # type: () -> Optional[str] + def correlation_id(self) -> Optional[str]: # pylint: disable=line-too-long """The correlation identifier. @@ -495,15 +491,13 @@ def correlation_id(self): return self._raw_amqp_message.properties.correlation_id @correlation_id.setter - def correlation_id(self, value): - # type: (str) -> None + def correlation_id(self, value: str) -> None: if not self._raw_amqp_message.properties: self._raw_amqp_message.properties = AmqpMessageProperties() self._raw_amqp_message.properties.correlation_id = value @property - def subject(self): - # type: () -> Optional[str] + def subject(self) -> Optional[str]: """The application specific subject, sometimes referred to as a label. 
This property enables the application to indicate the purpose of the message to the receiver in a standardized @@ -519,15 +513,13 @@ def subject(self): return self._raw_amqp_message.properties.subject @subject.setter - def subject(self, value): - # type: (str) -> None + def subject(self, value: str) -> None: if not self._raw_amqp_message.properties: self._raw_amqp_message.properties = AmqpMessageProperties() self._raw_amqp_message.properties.subject = value @property - def message_id(self): - # type: () -> Optional[str] + def message_id(self) -> Optional[str]: """The id to identify the message. The message identifier is an application-defined value that uniquely identifies the message and its payload. @@ -546,8 +538,7 @@ def message_id(self): return self._raw_amqp_message.properties.message_id @message_id.setter - def message_id(self, value): - # type: (str) -> None + def message_id(self, value: str) -> None: if value and len(str(value)) > MESSAGE_PROPERTY_MAX_LENGTH: raise ValueError( "message_id cannot be longer than {} characters.".format( @@ -559,8 +550,7 @@ def message_id(self, value): self._raw_amqp_message.properties.message_id = value @property - def reply_to(self): - # type: () -> Optional[str] + def reply_to(self) -> Optional[str]: # pylint: disable=line-too-long """The address of an entity to send replies to. @@ -581,15 +571,13 @@ def reply_to(self): return self._raw_amqp_message.properties.reply_to @reply_to.setter - def reply_to(self, value): - # type: (str) -> None + def reply_to(self, value: str) -> None: if not self._raw_amqp_message.properties: self._raw_amqp_message.properties = AmqpMessageProperties() self._raw_amqp_message.properties.reply_to = value @property - def reply_to_session_id(self): - # type: () -> Optional[str] + def reply_to_session_id(self) -> Optional[str]: # pylint: disable=line-too-long """The session identifier augmenting the `reply_to` address. 
@@ -609,8 +597,7 @@ def reply_to_session_id(self): return self._raw_amqp_message.properties.reply_to_group_id @reply_to_session_id.setter - def reply_to_session_id(self, value): - # type: (str) -> None + def reply_to_session_id(self, value: str) -> None: if value and len(value) > MESSAGE_PROPERTY_MAX_LENGTH: raise ValueError( "reply_to_session_id cannot be longer than {} characters.".format( @@ -623,8 +610,7 @@ def reply_to_session_id(self, value): self._raw_amqp_message.properties.reply_to_group_id = value @property - def to(self): - # type: () -> Optional[str] + def to(self) -> Optional[str]: """The `to` address. This property is reserved for future use in routing scenarios and presently ignored by the broker itself. @@ -643,8 +629,7 @@ def to(self): return self._raw_amqp_message.properties.to @to.setter - def to(self, value): - # type: (str) -> None + def to(self, value: str) -> None: if not self._raw_amqp_message.properties: self._raw_amqp_message.properties = AmqpMessageProperties() self._raw_amqp_message.properties.to = value @@ -666,49 +651,51 @@ class ServiceBusMessageBatch(object): can hold. 
""" - def __init__(self, max_size_in_bytes=None): - # type: (Optional[int]) -> None - self.message = uamqp.BatchMessage( - data=[], multi_messages=False, properties=None - ) - self._max_size_in_bytes = ( - max_size_in_bytes or uamqp.constants.MAX_MESSAGE_LENGTH_BYTES - ) - self._size = self.message.gather()[0].get_message_encoded_size() + def __init__( + self, + max_size_in_bytes: Optional[int] = None, + **kwargs: Any + ) -> None: + self._amqp_transport = kwargs.pop("amqp_transport", PyamqpTransport) + self._tracing_attributes: Dict[str, Union[str, int]] = kwargs.pop("tracing_attributes", {}) + + self._max_size_in_bytes = max_size_in_bytes or MAX_MESSAGE_LENGTH_BYTES + self._message = self._amqp_transport.build_batch_message([]) + self._size = self._amqp_transport.get_batch_message_encoded_size(self._message) self._count = 0 - self._messages = [] # type: List[ServiceBusMessage] + self._messages: List[ServiceBusMessage] = [] + self._uamqp_message: Optional[LegacyBatchMessage] = None - def __repr__(self): - # type: () -> str + def __repr__(self) -> str: batch_repr = "max_size_in_bytes={}, message_count={}".format( self.max_size_in_bytes, self._count ) return "ServiceBusMessageBatch({})".format(batch_repr) - def __len__(self): - # type: () -> int + def __len__(self) -> int: return self._count - def _from_list(self, messages: Iterable[ServiceBusMessage], parent_span: Optional["AbstractSpan"] = None) -> None: + def _from_list(self, messages: Iterable[ServiceBusMessage]) -> None: for message in messages: - self._add(message, parent_span) + self._add(message) - def _add( - self, - add_message: Union[ServiceBusMessage, Mapping[str, Any], AmqpAnnotatedMessage], - parent_span: Optional["AbstractSpan"] = None - ) -> None: + def _add(self, add_message: Union[ServiceBusMessage, Mapping[str, Any], AmqpAnnotatedMessage]) -> None: """Actual add implementation. 
The shim exists to hide the internal parameters such as parent_span.""" - message = transform_messages_if_needed(add_message, ServiceBusMessage) - message = cast(ServiceBusMessage, message) - trace_message( - message, parent_span - ) # parent_span is e.g. if built as part of a send operation. - message_size = ( - message.message.get_message_encoded_size() + outgoing_sb_message = transform_outbound_messages( + add_message, ServiceBusMessage, self._amqp_transport.to_outgoing_amqp_message + ) + outgoing_sb_message = cast(ServiceBusMessage, outgoing_sb_message) + # pylint: disable=protected-access + outgoing_sb_message._message = trace_message( + outgoing_sb_message._message, + amqp_transport=self._amqp_transport, + additional_attributes=self._tracing_attributes + ) + message_size = self._amqp_transport.get_message_encoded_size( + outgoing_sb_message._message # pylint: disable=protected-access ) - # For a ServiceBusMessageBatch, if the encoded_message_size of event_data is < 256, then the overhead cost to + # For a ServiceBusMessageBatch, if the encoded_message_size of message is < 256, then the overhead cost to # encode that message into the ServiceBusMessageBatch would be 5 bytes, if >= 256, it would be 8 bytes. 
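The overhead rule stated in the comment above (5 bytes of batch framing when the encoded message is under 256 bytes, 8 bytes otherwise) can be sketched in plain Python; the names below are illustrative, not the SDK's:

```python
class BatchSizeError(Exception):
    """Stand-in for the SDK's MessageSizeExceededError."""

def batch_entry_overhead(encoded_size: int) -> int:
    # Encoding a message into a batch costs 5 extra bytes when the
    # encoded message is under 256 bytes, and 8 bytes otherwise.
    return 5 if encoded_size < 256 else 8

def try_add(current_size: int, message_size: int, max_size: int) -> int:
    size_after_add = current_size + message_size + batch_entry_overhead(message_size)
    if size_after_add > max_size:
        raise BatchSizeError(f"batch has reached its size limit: {max_size}")
    return size_after_add

print(try_add(100, 200, 1024))  # 305 (200 < 256, so +5 overhead)
print(try_add(100, 300, 1024))  # 408 (300 >= 256, so +8 overhead)
```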
size_after_add = ( self._size @@ -718,19 +705,48 @@ def _add( if size_after_add > self.max_size_in_bytes: raise MessageSizeExceededError( - message="ServiceBusMessageBatch has reached its size limit: {}".format( - self.max_size_in_bytes - ) + message=f"ServiceBusMessageBatch has reached its size limit: {self.max_size_in_bytes}" ) - - self.message._body_gen.append(message) # pylint: disable=protected-access + self._amqp_transport.add_batch(self, outgoing_sb_message) # pylint: disable=protected-access self._size = size_after_add self._count += 1 - self._messages.append(message) + self._messages.append(outgoing_sb_message) + + @property + def message(self) -> Union["BatchMessage", LegacyBatchMessage]: + """DEPRECATED: Get the underlying uamqp.BatchMessage or LegacyBatchMessage. + This is deprecated and will be removed in a later release. + + :rtype: uamqp.BatchMessage or LegacyBatchMessage + """ + warnings.warn( + "The `message` property is deprecated and will be removed in future versions.", + DeprecationWarning, + ) + if not self._uamqp_message: + if self._amqp_transport.KIND == "pyamqp": + message = AmqpAnnotatedMessage(message=pyamqp_Message(*self._message)) + self._uamqp_message = LegacyBatchMessage( + message, + to_outgoing_amqp_message=PyamqpTransport.to_outgoing_amqp_message, + ) + else: + self._uamqp_message = self._message + return self._uamqp_message + + @message.setter + def message(self, value: "BatchMessage") -> None: + """DEPRECATED: Set the underlying BatchMessage. + This is deprecated and will be removed in a later release. + """ + warnings.warn( + "The `message` property is deprecated and will be removed in future versions.", + DeprecationWarning, + ) + self._uamqp_message = value @property - def max_size_in_bytes(self): - # type: () -> int + def max_size_in_bytes(self) -> int: """The maximum size of bytes data that a ServiceBusMessageBatch object can hold. 
:rtype: int @@ -738,16 +754,14 @@ def max_size_in_bytes(self): return self._max_size_in_bytes @property - def size_in_bytes(self): - # type: () -> int + def size_in_bytes(self) -> int: """The combined size of the messages in the batch, in bytes. :rtype: int """ return self._size - def add_message(self, message): - # type: (Union[ServiceBusMessage, AmqpAnnotatedMessage, Mapping[str, Any]]) -> None + def add_message(self, message: Union[ServiceBusMessage, AmqpAnnotatedMessage, Mapping[str, Any]]) -> None: """Try to add a single Message to the batch. The total size of an added message is the sum of its body, properties, etc. @@ -763,7 +777,7 @@ def add_message(self, message): return self._add(message) -class ServiceBusReceivedMessage(ServiceBusMessage): +class ServiceBusReceivedMessage(ServiceBusMessage): # pylint: disable=too-many-instance-attributes """ A Service Bus Message received from service side. @@ -781,30 +795,48 @@ class ServiceBusReceivedMessage(ServiceBusMessage): """ - def __init__(self, message, receive_mode=ServiceBusReceiveMode.PEEK_LOCK, **kwargs): - # type: (uamqp.message.Message, Union[ServiceBusReceiveMode, str], Any) -> None + def __init__( + self, + message: Union["Message", "pyamqp_Message"], + receive_mode: Union[ServiceBusReceiveMode, str] = ServiceBusReceiveMode.PEEK_LOCK, + frame: Optional["TransferFrame"] = None, + **kwargs + ) -> None: + self._amqp_transport = kwargs.pop("amqp_transport", PyamqpTransport) super(ServiceBusReceivedMessage, self).__init__(None, message=message) # type: ignore + if self._amqp_transport.KIND == "uamqp": + self._uamqp_message = message + self._message = message self._settled = receive_mode == ServiceBusReceiveMode.RECEIVE_AND_DELETE + self._delivery_tag = self._amqp_transport.get_message_delivery_tag(message, frame) + self._delivery_id = self._amqp_transport.get_message_delivery_id(message, frame) # only used by pyamqp self._received_timestamp_utc = utc_now() self._is_deferred_message = 
kwargs.get("is_deferred_message", False) self._is_peeked_message = kwargs.get("is_peeked_message", False) - self.auto_renew_error = None # type: Optional[Exception] + self.auto_renew_error: Optional[Exception] = None try: - self._receiver = kwargs.pop( + self._receiver: Union["ServiceBusReceiver", "AsyncServiceBusReceiver"] = kwargs.pop( "receiver" - ) # type: Union[ServiceBusReceiver, AsyncServiceBusReceiver] + ) except KeyError: raise TypeError( "ServiceBusReceivedMessage requires a receiver to be initialized. " + "This class should never be initialized by a user; " + "for outgoing messages, the ServiceBusMessage class should be utilized instead." ) - self._expiry = None # type: Optional[datetime.datetime] + self._expiry: Optional[datetime.datetime] = None + + def __getstate__(self): + state = self.__dict__.copy() + state['_receiver'] = None + state['_uamqp_message'] = None + return state + + def __setstate__(self, state): + self.__dict__.update(state) @property - def _lock_expired(self) -> bool: """ Whether the lock on the message has expired.
@@ -821,13 +853,7 @@ def _lock_expired(self): return True return False - def _to_outgoing_message(self): - # type: () -> ServiceBusMessage - # pylint: disable=protected-access - return ServiceBusMessage(body=None, message=self.raw_amqp_message._to_outgoing_amqp_message()) - - def __repr__(self): # pylint: disable=too-many-branches,too-many-statements - # type: () -> str + def __repr__(self) -> str: # pylint: disable=too-many-branches,too-many-statements # pylint: disable=bare-except message_repr = "body={}".format( str(self) @@ -926,9 +952,34 @@ def __repr__(self): # pylint: disable=too-many-branches,too-many-statements message_repr += ", locked_until_utc=" return "ServiceBusReceivedMessage({})".format(message_repr)[:1024] + @property # type: ignore[misc] # TODO: ignoring error to copy over setter, since it's inherited + def message(self) -> Union["Message", LegacyMessage]: + """DEPRECATED: Get the underlying LegacyMessage. + This is deprecated and will be removed in a later release. + + :rtype: LegacyMessage + """ + warnings.warn( + "The `message` property is deprecated and will be removed in future versions.", + DeprecationWarning, + ) + if not self._uamqp_message: + if not self._settled: + settler = self._receiver._handler # pylint:disable=protected-access + else: + settler = None + self._uamqp_message = LegacyMessage( + self._raw_amqp_message, + delivery_no=self._delivery_id, + delivery_tag=self._delivery_tag, + settler=settler, + encoding=self._encoding, + to_outgoing_amqp_message=PyamqpTransport.to_outgoing_amqp_message + ) + return self._uamqp_message + @property - def dead_letter_error_description(self): - # type: () -> Optional[str] + def dead_letter_error_description(self) -> Optional[str]: """ Dead letter error description, when the message is received from a deadletter subqueue of an entity. 
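The new `__getstate__`/`__setstate__` overrides above make a received message picklable by dropping the live receiver and uamqp handles, which cannot cross process boundaries. A simplified sketch of that pattern (class name hypothetical):

```python
import pickle

class ReceivedMessageLike:
    def __init__(self, body: bytes, receiver: object) -> None:
        self.body = body
        self._receiver = receiver  # wraps a live AMQP link; not picklable

    def __getstate__(self):
        state = self.__dict__.copy()
        state["_receiver"] = None  # drop live handles before serializing
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)

msg = ReceivedMessageLike(b"payload", receiver=object())
restored = pickle.loads(pickle.dumps(msg))
print(restored.body, restored._receiver)  # b'payload' None
```

A message restored this way keeps its payload and metadata but can no longer be settled, since settlement requires the original receiver's link.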
@@ -944,8 +995,7 @@ def dead_letter_error_description(self): return None @property - def dead_letter_reason(self): - # type: () -> Optional[str] + def dead_letter_reason(self) -> Optional[str]: """ Dead letter reason, when the message is received from a deadletter subqueue of an entity. @@ -961,8 +1011,7 @@ def dead_letter_reason(self): return None @property - def dead_letter_source(self): - # type: () -> Optional[str] + def dead_letter_source(self) -> Optional[str]: """ The name of the queue or subscription that this message was enqueued on, before it was deadlettered. This property is only set in messages that have been dead-lettered and subsequently auto-forwarded @@ -980,8 +1029,7 @@ def dead_letter_source(self): return None @property - def state(self): - # type: () -> ServiceBusMessageState + def state(self) -> ServiceBusMessageState: """ Defaults to Active. Represents the message state of the message. Can be Active, Deferred, or Scheduled. @@ -998,8 +1046,7 @@ def state(self): return ServiceBusMessageState.ACTIVE @property - def delivery_count(self): - # type: () -> Optional[int] + def delivery_count(self) -> Optional[int]: """ Number of deliveries that have been attempted for this message. The count is incremented when a message lock expires or the message is explicitly abandoned by the receiver. @@ -1011,8 +1058,7 @@ def delivery_count(self): return None @property - def enqueued_sequence_number(self): - # type: () -> Optional[int] + def enqueued_sequence_number(self) -> Optional[int]: """ For messages that have been auto-forwarded, this property reflects the sequence number that had first been assigned to the message at its original point of submission. @@ -1024,8 +1070,7 @@ def enqueued_sequence_number(self): return None @property - def enqueued_time_utc(self): - # type: () -> Optional[datetime.datetime] + def enqueued_time_utc(self) -> Optional[datetime.datetime]: """ The UTC datetime at which the message has been accepted and stored in the entity.
@@ -1039,8 +1084,7 @@ def enqueued_time_utc(self): return None @property - def expires_at_utc(self): - # type: () -> Optional[datetime.datetime] + def expires_at_utc(self) -> Optional[datetime.datetime]: """ The UTC datetime at which the message is marked for removal and no longer available for retrieval from the entity due to expiration. Expiry is controlled by the `Message.time_to_live` property. @@ -1053,8 +1097,7 @@ def expires_at_utc(self): return None @property - def sequence_number(self): - # type: () -> Optional[int] + def sequence_number(self) -> Optional[int]: """ The unique number assigned to a message by Service Bus. The sequence number is a unique 64-bit integer assigned to a message as it is accepted and stored by the broker and functions as its true identifier. @@ -1068,8 +1111,7 @@ def sequence_number(self): return None @property - def lock_token(self): - # type: () -> Optional[Union[uuid.UUID, str]] + def lock_token(self) -> Optional[Union[uuid.UUID, str]]: """ The lock token for the current message serving as a reference to the lock that is being held by the broker in PEEK_LOCK mode. @@ -1079,8 +1121,8 @@ def lock_token(self): if self._settled: return None - if self.message.delivery_tag: - return uuid.UUID(bytes_le=self.message.delivery_tag) + if self._delivery_tag: + return uuid.UUID(bytes_le=self._delivery_tag) delivery_annotations = self._raw_amqp_message.delivery_annotations if delivery_annotations: @@ -1088,12 +1130,10 @@ def lock_token(self): return None @property - def locked_until_utc(self): - # type: () -> Optional[datetime.datetime] - # pylint: disable=protected-access + def locked_until_utc(self) -> Optional[datetime.datetime]: """ The UTC datetime until which the message will be locked in the queue/subscription. - When the lock expires, delivery count of hte message is incremented and the message + When the lock expires, delivery count of the message is incremented and the message is again available for retrieval. 
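The reworked `lock_token` above reads the token from the stored AMQP delivery tag, which encodes the lock token as a little-endian GUID; `uuid.UUID(bytes_le=...)` reverses that encoding. A standalone sketch (function name hypothetical):

```python
import uuid
from typing import Optional

def lock_token_from_delivery_tag(delivery_tag: Optional[bytes]) -> Optional[uuid.UUID]:
    # The service sends the lock token as the 16-byte, little-endian
    # representation of a GUID in the transfer frame's delivery tag.
    if not delivery_tag:
        return None
    return uuid.UUID(bytes_le=delivery_tag)

token = uuid.UUID("12345678-1234-5678-1234-567812345678")
assert lock_token_from_delivery_tag(token.bytes_le) == token
print(lock_token_from_delivery_tag(None))  # None
```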
:rtype: datetime.datetime diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_common/mgmt_handlers.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_common/mgmt_handlers.py index 660382b9839da..df56fdba3bb2a 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/_common/mgmt_handlers.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_common/mgmt_handlers.py @@ -5,95 +5,92 @@ # ------------------------------------------------------------------------- import logging -import uamqp + from .message import ServiceBusReceivedMessage -from ..exceptions import _handle_amqp_mgmt_error from .constants import ServiceBusReceiveMode, MGMT_RESPONSE_MESSAGE_ERROR_CONDITION _LOGGER = logging.getLogger(__name__) def default( # pylint: disable=inconsistent-return-statements - status_code, message, description + status_code, message, description, amqp_transport ): condition = message.application_properties.get( MGMT_RESPONSE_MESSAGE_ERROR_CONDITION ) if status_code == 200: - return message.get_data() + return message.value - _handle_amqp_mgmt_error( + amqp_transport.handle_amqp_mgmt_error( # pylint: disable=protected-access _LOGGER, "Service request failed.", condition, description, status_code ) def session_lock_renew_op( # pylint: disable=inconsistent-return-statements - status_code, message, description + status_code, message, description, amqp_transport ): condition = message.application_properties.get( MGMT_RESPONSE_MESSAGE_ERROR_CONDITION ) if status_code == 200: - return message.get_data() + return message.value - _handle_amqp_mgmt_error( + amqp_transport.handle_amqp_mgmt_error( # pylint: disable=protected-access _LOGGER, "Session lock renew failed.", condition, description, status_code ) def message_lock_renew_op( # pylint: disable=inconsistent-return-statements - status_code, message, description + status_code, message, description, amqp_transport ): condition = message.application_properties.get( MGMT_RESPONSE_MESSAGE_ERROR_CONDITION ) if 
status_code == 200: - return message.get_data() + # TODO: will this always be body type ValueType? + return message.value - _handle_amqp_mgmt_error( + amqp_transport.handle_amqp_mgmt_error( # pylint: disable=protected-access _LOGGER, "Message lock renew failed.", condition, description, status_code ) def peek_op( # pylint: disable=inconsistent-return-statements - status_code, message, description, receiver + status_code, message, description, receiver, amqp_transport ): condition = message.application_properties.get( MGMT_RESPONSE_MESSAGE_ERROR_CONDITION ) if status_code == 200: - parsed = [] - for m in message.get_data()[b"messages"]: - wrapped = uamqp.Message.decode_from_bytes(bytearray(m[b"message"])) - parsed.append( - ServiceBusReceivedMessage( - wrapped, is_peeked_message=True, receiver=receiver - ) + return amqp_transport.parse_received_message( + message, + message_type=ServiceBusReceivedMessage, + receiver=receiver, + is_peeked_message=True ) - return parsed if status_code in [202, 204]: return [] - _handle_amqp_mgmt_error( + amqp_transport.handle_amqp_mgmt_error( # pylint: disable=protected-access _LOGGER, "Message peek failed.", condition, description, status_code ) def list_sessions_op( # pylint: disable=inconsistent-return-statements - status_code, message, description + status_code, message, description, amqp_transport ): condition = message.application_properties.get( MGMT_RESPONSE_MESSAGE_ERROR_CONDITION ) if status_code == 200: parsed = [] - for m in message.get_data()[b"sessions-ids"]: + for m in amqp_transport.get_message_value(message)[b"sessions-ids"]: parsed.append(m.decode("UTF-8")) return parsed if status_code in [202, 204]: return [] - _handle_amqp_mgmt_error( + amqp_transport.handle_amqp_mgmt_error( _LOGGER, "List sessions failed.", condition, description, status_code ) @@ -103,6 +100,7 @@ def deferred_message_op( # pylint: disable=inconsistent-return-statements message, description, receiver, + amqp_transport, 
receive_mode=ServiceBusReceiveMode.PEEK_LOCK, message_type=ServiceBusReceivedMessage, ): @@ -110,19 +108,17 @@ def deferred_message_op( # pylint: disable=inconsistent-return-statements MGMT_RESPONSE_MESSAGE_ERROR_CONDITION ) if status_code == 200: - parsed = [] - for m in message.get_data()[b"messages"]: - wrapped = uamqp.Message.decode_from_bytes(bytearray(m[b"message"])) - parsed.append( - message_type( - wrapped, receive_mode, is_deferred_message=True, receiver=receiver - ) + return amqp_transport.parse_received_message( + message, + message_type=message_type, + receiver=receiver, + receive_mode=receive_mode, + is_deferred_message=True ) - return parsed if status_code in [202, 204]: return [] - _handle_amqp_mgmt_error( + amqp_transport.handle_amqp_mgmt_error( _LOGGER, "Retrieving deferred messages failed.", condition, @@ -132,14 +128,14 @@ def deferred_message_op( # pylint: disable=inconsistent-return-statements def schedule_op( # pylint: disable=inconsistent-return-statements - status_code, message, description + status_code, message, description, amqp_transport ): condition = message.application_properties.get( MGMT_RESPONSE_MESSAGE_ERROR_CONDITION ) if status_code == 200: - return message.get_data()[b"sequence-numbers"] + return message.value[b"sequence-numbers"] - _handle_amqp_mgmt_error( + amqp_transport.handle_amqp_mgmt_error( _LOGGER, "Scheduling messages failed.", condition, description, status_code ) diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_common/receiver_mixins.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_common/receiver_mixins.py index 9c0962e4ce88f..670a0c5de18ef 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/_common/receiver_mixins.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_common/receiver_mixins.py @@ -4,33 +4,22 @@ # license information. 
# ------------------------------------------------------------------------- import uuid -import functools -from typing import Optional, Callable +from typing import TYPE_CHECKING, Union -from uamqp import Source - -from .message import ServiceBusReceivedMessage +from ..exceptions import MessageAlreadySettled from .constants import ( NEXT_AVAILABLE_SESSION, - SESSION_FILTER, - SESSION_LOCKED_UNTIL, - DATETIMEOFFSET_EPOCH, MGMT_REQUEST_SESSION_ID, ServiceBusReceiveMode, - DEADLETTERNAME, - RECEIVER_LINK_DEAD_LETTER_REASON, - RECEIVER_LINK_DEAD_LETTER_ERROR_DESCRIPTION, - MESSAGE_COMPLETE, - MESSAGE_DEAD_LETTER, - MESSAGE_ABANDON, - MESSAGE_DEFER, ) -from ..exceptions import _ServiceBusErrorPolicy, MessageAlreadySettled -from .utils import utc_from_timestamp, utc_now +if TYPE_CHECKING: + from .._transport._base import AmqpTransport + from ..aio._transport._base_async import AmqpTransportAsync class ReceiverMixin(object): # pylint: disable=too-many-instance-attributes def _populate_attributes(self, **kwargs): + self._amqp_transport: Union["AmqpTransport", "AmqpTransportAsync"] if kwargs.get("subscription_name"): self._subscription_name = kwargs.get("subscription_name") self._is_subscription = True @@ -51,8 +40,10 @@ def _populate_attributes(self, **kwargs): ) self._session_id = kwargs.get("session_id") - self._error_policy = _ServiceBusErrorPolicy( - max_retries=self._config.retry_total, is_session=bool(self._session_id) + + self._error_policy = self._amqp_transport.create_retry_policy( + config=self._config, + is_session=bool(self._session_id) ) self._name = kwargs.get("client_identifier", "SBReceiver-{}".format(uuid.uuid4())) @@ -68,7 +59,7 @@ def _populate_attributes(self, **kwargs): # The relationship between the amount can be received and the time interval is linear: amount ~= perf * interval # In large max_message_count case, like 5000, the pull receive would always return hundreds of messages limited # by the perf and time. 
- self._further_pull_receive_timeout_ms = 200 + self._further_pull_receive_timeout = 0.2 * self._amqp_transport.TIMEOUT_FACTOR max_wait_time = kwargs.get("max_wait_time", None) if max_wait_time is not None and max_wait_time <= 0: raise ValueError("The max_wait_time must be greater than 0.") @@ -85,39 +76,24 @@ def _populate_attributes(self, **kwargs): "as they have been deleted, providing an AutoLockRenewer in this mode is invalid." ) - def _build_message(self, received, message_type=ServiceBusReceivedMessage): - message = message_type( - message=received, receive_mode=self._receive_mode, receiver=self - ) - self._last_received_sequenced_number = message.sequence_number - return message - def _get_source(self): # pylint: disable=protected-access if self._session: - source = Source(self._entity_uri) - session_filter = ( - None if self._session_id == NEXT_AVAILABLE_SESSION else self._session_id - ) - source.set_filter(session_filter, name=SESSION_FILTER, descriptor=None) - return source + session_filter = None if self._session_id == NEXT_AVAILABLE_SESSION else self._session_id + return self._amqp_transport.create_source(self._entity_uri, session_filter) return self._entity_uri def _check_message_alive(self, message, action): # pylint: disable=no-member, protected-access if message._is_peeked_message: raise ValueError( - "The operation {} is not supported for peeked messages." - "Only messages received using receive methods in PEEK_LOCK mode can be settled.".format( - action - ) + f"The operation {action} is not supported for peeked messages." + "Only messages received using receive methods in PEEK_LOCK mode can be settled." ) if self._receive_mode == ServiceBusReceiveMode.RECEIVE_AND_DELETE: raise ValueError( - "The operation {} is not supported in 'RECEIVE_AND_DELETE' receive mode.".format( - action - ) + f"The operation {action} is not supported in 'RECEIVE_AND_DELETE' receive mode." 
) if message._settled: @@ -125,62 +101,10 @@ def _check_message_alive(self, message, action): if not self._running: raise ValueError( - "Failed to {} the message as the handler has already been shutdown." - "Please use ServiceBusClient to create a new instance.".format(action) - ) - - def _settle_message_via_receiver_link( - self, - message, - settle_operation, - dead_letter_reason=None, - dead_letter_error_description=None, - ): - # type: (ServiceBusReceivedMessage, str, Optional[str], Optional[str]) -> Callable - # pylint: disable=no-self-use - if settle_operation == MESSAGE_COMPLETE: - return functools.partial(message.message.accept) - if settle_operation == MESSAGE_ABANDON: - return functools.partial(message.message.modify, True, False) - if settle_operation == MESSAGE_DEAD_LETTER: - return functools.partial( - message.message.reject, - condition=DEADLETTERNAME, - description=dead_letter_error_description, - info={ - RECEIVER_LINK_DEAD_LETTER_REASON: dead_letter_reason, - RECEIVER_LINK_DEAD_LETTER_ERROR_DESCRIPTION: dead_letter_error_description, - }, + f"Failed to {action} the message as the handler has already been shutdown." + "Please use ServiceBusClient to create a new instance." ) - if settle_operation == MESSAGE_DEFER: - return functools.partial(message.message.modify, True, True) - raise ValueError( - "Unsupported settle operation type: {}".format(settle_operation) - ) - - def _on_attach(self, source, target, properties, error): - # pylint: disable=protected-access, unused-argument - if self._session and str(source) == self._entity_uri: - # This has to live on the session object so that autorenew has access to it. 
- self._session._session_start = utc_now() - expiry_in_seconds = properties.get(SESSION_LOCKED_UNTIL) - if expiry_in_seconds: - expiry_in_seconds = ( - expiry_in_seconds - DATETIMEOFFSET_EPOCH - ) / 10000000 - self._session._locked_until_utc = utc_from_timestamp(expiry_in_seconds) - session_filter = source.get_filter(name=SESSION_FILTER) - self._session_id = session_filter.decode(self._config.encoding) - self._session._session_id = self._session_id def _populate_message_properties(self, message): if self._session: message[MGMT_REQUEST_SESSION_ID] = self._session_id - - def _enhanced_message_received(self, message): - # pylint: disable=protected-access - self._handler._was_message_received = True - if self._receive_context.is_set(): - self._handler._received_messages.put(message) - else: - message.release() diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_common/tracing.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_common/tracing.py new file mode 100644 index 0000000000000..142be9ffdaa03 --- /dev/null +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_common/tracing.py @@ -0,0 +1,302 @@ +# ------------------------------------------------------------------------ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. 
+# ------------------------------------------------------------------------- +from __future__ import annotations +from contextlib import contextmanager +from enum import Enum +import logging +from typing import ( + Dict, + Iterable, + Iterator, + List, + Optional, + Type, + TYPE_CHECKING, + Union, + cast, +) + +from azure.core import CaseInsensitiveEnumMeta +from azure.core.settings import settings +from azure.core.tracing import SpanKind, Link + +if TYPE_CHECKING: + try: + # pylint:disable=unused-import + from uamqp import Message as uamqp_Message + except ImportError: + uamqp_Message = None + from azure.core.tracing import AbstractSpan + + from .._pyamqp.message import Message as pyamqp_Message + from .message import ( + ServiceBusReceivedMessage, + ServiceBusMessage, + ServiceBusMessageBatch + ) + from .._base_handler import BaseHandler + from ..aio._base_handler_async import BaseHandler as BaseHandlerAsync + from .._servicebus_receiver import ServiceBusReceiver + from ..aio._servicebus_receiver_async import ServiceBusReceiver as ServiceBusReceiverAsync + from .._servicebus_sender import ServiceBusSender + from ..aio._servicebus_sender_async import ServiceBusSender as ServiceBusSenderAsync + from .._transport._base import AmqpTransport + from ..aio._transport._base_async import AmqpTransportAsync + + ReceiveMessageTypes = Union[ + ServiceBusReceivedMessage, + pyamqp_Message, + uamqp_Message + ] + +TRACE_DIAGNOSTIC_ID_PROPERTY = b"Diagnostic-Id" +TRACE_ENQUEUED_TIME_PROPERTY = b"x-opt-enqueued-time" +TRACE_PARENT_PROPERTY = b"traceparent" +TRACE_STATE_PROPERTY = b"tracestate" +TRACE_PROPERTY_ENCODING = "ascii" + +SPAN_ENQUEUED_TIME_PROPERTY = "enqueuedTime" + +SPAN_NAME_RECEIVE = "ServiceBus.receive" +SPAN_NAME_RECEIVE_DEFERRED = "ServiceBus.receive_deferred" +SPAN_NAME_PEEK = "ServiceBus.peek" +SPAN_NAME_SEND = "ServiceBus.send" +SPAN_NAME_SCHEDULE = "ServiceBus.schedule" +SPAN_NAME_MESSAGE = "ServiceBus.message" + + +_LOGGER = logging.getLogger(__name__) + + 
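The context managers in this new tracing module share one pattern: when `settings.tracing_implementation()` returns `None`, they degrade to a bare `yield`, so callers pay nothing when tracing is off. A dependency-free sketch of that pattern (all names hypothetical):

```python
from contextlib import contextmanager
from typing import Callable, Iterator, Optional

@contextmanager
def maybe_traced(get_impl: Callable[[], Optional[Callable]], name: str) -> Iterator[None]:
    impl = get_impl()
    if impl is not None:
        # A tracing plugin is configured: wrap the body in a span.
        with impl(name):
            yield
    else:
        # Tracing disabled: plain passthrough, no span objects created.
        yield

events = []

@contextmanager
def fake_span(name: str):
    events.append(f"start {name}")
    yield
    events.append(f"end {name}")

with maybe_traced(lambda: fake_span, "ServiceBus.send"):
    events.append("body")
print(events)  # ['start ServiceBus.send', 'body', 'end ServiceBus.send']
```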
+class TraceAttributes: + TRACE_NAMESPACE_ATTRIBUTE = "az.namespace" + TRACE_NAMESPACE = "Microsoft.ServiceBus" + + TRACE_MESSAGING_SYSTEM_ATTRIBUTE = "messaging.system" + TRACE_MESSAGING_SYSTEM = "servicebus" + + TRACE_NET_PEER_NAME_ATTRIBUTE = "net.peer.name" + TRACE_MESSAGING_DESTINATION_ATTRIBUTE = "messaging.destination.name" + TRACE_MESSAGING_OPERATION_ATTRIBUTE = "messaging.operation" + TRACE_MESSAGING_BATCH_COUNT_ATTRIBUTE = "messaging.batch.message_count" + + LEGACY_TRACE_COMPONENT_ATTRIBUTE = "component" + LEGACY_TRACE_MESSAGE_BUS_DESTINATION_ATTRIBUTE = "message_bus.destination" + LEGACY_TRACE_PEER_ADDRESS_ATTRIBUTE = "peer.address" + + +class TraceOperationTypes(str, Enum, metaclass=CaseInsensitiveEnumMeta): + PUBLISH = "publish" + RECEIVE = "receive" + SETTLE = "settle" + + +def is_tracing_enabled(): + span_impl_type = settings.tracing_implementation() + return span_impl_type is not None + + +@contextmanager +def send_trace_context_manager( + sender: Union[ServiceBusSender, ServiceBusSenderAsync], + span_name: str = SPAN_NAME_SEND, + links: Optional[List[Link]] = None +) -> Iterator[None]: + """Tracing for sending messages.""" + span_impl_type: Type[AbstractSpan] = settings.tracing_implementation() + + if span_impl_type is not None: + links = links or [] + with span_impl_type(name=span_name, kind=SpanKind.CLIENT, links=links) as span: + add_span_attributes(span, TraceOperationTypes.PUBLISH, sender, message_count=len(links)) + yield + else: + yield + + +@contextmanager +def receive_trace_context_manager( + receiver: Union[ServiceBusReceiver, ServiceBusReceiverAsync], + span_name: str = SPAN_NAME_RECEIVE, + links: Optional[List[Link]] = None, + start_time: Optional[int] = None +) -> Iterator[None]: + """Tracing for receiving messages.""" + span_impl_type: Type[AbstractSpan] = settings.tracing_implementation() + if span_impl_type is not None: + links = links or [] + with span_impl_type(name=span_name, kind=SpanKind.CLIENT, links=links, 
start_time=start_time) as span: + add_span_attributes(span, TraceOperationTypes.RECEIVE, receiver, message_count=len(links)) + yield + else: + yield + + +@contextmanager +def settle_trace_context_manager( + receiver: Union[ServiceBusReceiver, ServiceBusReceiverAsync], + operation: str, + links: Optional[List[Link]] = None +): + """Tracing for settling messages.""" + span_impl_type = settings.tracing_implementation() + if span_impl_type is not None: + links = links or [] + with span_impl_type(name=f"ServiceBus.{operation}", kind=SpanKind.CLIENT, links=links) as span: + add_span_attributes(span, TraceOperationTypes.SETTLE, receiver) + yield + else: + yield + + +def trace_message( + message: Union[uamqp_Message, pyamqp_Message], + amqp_transport: Union[AmqpTransport, AmqpTransportAsync], + additional_attributes: Optional[Dict[str, Union[str, int]]] = None +) -> Union["uamqp_Message", "pyamqp_Message"]: + """Adds tracing information to the message and returns the updated message. + + Will open and close a message span, and add tracing context to the app properties of the message. 
+ """ + try: + span_impl_type: Type[AbstractSpan] = settings.tracing_implementation() + if span_impl_type is not None: + with span_impl_type(name=SPAN_NAME_MESSAGE, kind=SpanKind.PRODUCER) as message_span: + headers = message_span.to_header() + + if "traceparent" in headers: + message = amqp_transport.update_message_app_properties( + message, + TRACE_DIAGNOSTIC_ID_PROPERTY, + headers["traceparent"] + ) + message = amqp_transport.update_message_app_properties( + message, + TRACE_PARENT_PROPERTY, + headers["traceparent"] + ) + + if "tracestate" in headers: + message = amqp_transport.update_message_app_properties( + message, + TRACE_STATE_PROPERTY, + headers["tracestate"] + ) + + message_span.add_attribute( + TraceAttributes.TRACE_NAMESPACE_ATTRIBUTE, TraceAttributes.TRACE_NAMESPACE + ) + message_span.add_attribute( + TraceAttributes.TRACE_MESSAGING_SYSTEM_ATTRIBUTE, TraceAttributes.TRACE_MESSAGING_SYSTEM + ) + + if additional_attributes: + for key, value in additional_attributes.items(): + if value is not None: + message_span.add_attribute(key, value) + + except Exception as exp: # pylint:disable=broad-except + _LOGGER.warning("trace_message had an exception %r", exp) + + return message + + +def get_receive_links(messages: Union[ReceiveMessageTypes, Iterable[ReceiveMessageTypes]]) -> List[Link]: + if not is_tracing_enabled(): + return [] + + trace_messages = ( + messages if isinstance(messages, Iterable) # pylint:disable=isinstance-second-argument-not-valid-type + else (messages,) + ) + + links = [] + try: + for message in trace_messages: + if message.application_properties: + headers = {} + + traceparent = message.application_properties.get(TRACE_PARENT_PROPERTY, b"") + if hasattr(traceparent, "decode"): + traceparent = traceparent.decode(TRACE_PROPERTY_ENCODING) + if traceparent: + headers["traceparent"] = cast(str, traceparent) + + tracestate = message.application_properties.get(TRACE_STATE_PROPERTY, b"") + if hasattr(tracestate, "decode"): + tracestate = 
tracestate.decode(TRACE_PROPERTY_ENCODING) + if tracestate: + headers["tracestate"] = cast(str, tracestate) + + enqueued_time = message.raw_amqp_message.annotations.get(TRACE_ENQUEUED_TIME_PROPERTY) + attributes = {SPAN_ENQUEUED_TIME_PROPERTY: enqueued_time} if enqueued_time else None + + if headers: + links.append(Link(headers, attributes=attributes)) + except AttributeError: + pass + return links + + +def get_span_links_from_batch(batch: ServiceBusMessageBatch) -> List[Link]: + """Create span links from a batch of messages.""" + links = [] + for message in batch._messages: # pylint: disable=protected-access + link = get_span_link_from_message(message._message) # pylint: disable=protected-access + if link: + links.append(link) + return links + + +def get_span_link_from_message(message: Union[uamqp_Message, pyamqp_Message, ServiceBusMessage]) -> Optional[Link]: + """Create a span link from a message. + + This will extract the traceparent and tracestate from the message application properties and create span links + based on these values. 
+ """ + headers = {} + try: + if message.application_properties: + traceparent = message.application_properties.get(TRACE_PARENT_PROPERTY, b"") + if hasattr(traceparent, "decode"): + traceparent = traceparent.decode(TRACE_PROPERTY_ENCODING) + if traceparent: + headers["traceparent"] = cast(str, traceparent) + + tracestate = message.application_properties.get(TRACE_STATE_PROPERTY, b"") + if hasattr(tracestate, "decode"): + tracestate = tracestate.decode(TRACE_PROPERTY_ENCODING) + if tracestate: + headers["tracestate"] = cast(str, tracestate) + except AttributeError: + return None + return Link(headers) if headers else None + + +def add_span_attributes( + span: AbstractSpan, + operation_type: TraceOperationTypes, + handler: Union[BaseHandler, BaseHandlerAsync], + message_count: int = 0 +) -> None: + """Add attributes to span based on the operation type.""" + + span.add_attribute(TraceAttributes.TRACE_NAMESPACE_ATTRIBUTE, TraceAttributes.TRACE_NAMESPACE) + span.add_attribute(TraceAttributes.TRACE_MESSAGING_SYSTEM_ATTRIBUTE, TraceAttributes.TRACE_MESSAGING_SYSTEM) + span.add_attribute(TraceAttributes.TRACE_MESSAGING_OPERATION_ATTRIBUTE, operation_type) + + if message_count > 1: + span.add_attribute(TraceAttributes.TRACE_MESSAGING_BATCH_COUNT_ATTRIBUTE, message_count) + + if operation_type in (TraceOperationTypes.PUBLISH, TraceOperationTypes.RECEIVE): + # Maintain legacy attributes for backwards compatibility.
+ span.add_attribute(TraceAttributes.LEGACY_TRACE_COMPONENT_ATTRIBUTE, TraceAttributes.TRACE_MESSAGING_SYSTEM) + span.add_attribute(TraceAttributes.LEGACY_TRACE_MESSAGE_BUS_DESTINATION_ATTRIBUTE, handler._entity_name) # pylint: disable=protected-access + span.add_attribute(TraceAttributes.LEGACY_TRACE_PEER_ADDRESS_ATTRIBUTE, handler.fully_qualified_namespace) + + elif operation_type == TraceOperationTypes.SETTLE: + span.add_attribute(TraceAttributes.TRACE_NET_PEER_NAME_ATTRIBUTE, handler.fully_qualified_namespace) + span.add_attribute(TraceAttributes.TRACE_MESSAGING_DESTINATION_ATTRIBUTE, handler._entity_name) # pylint: disable=protected-access diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_common/utils.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_common/utils.py index 1fd365e9b9a3d..7753ea849e4ca 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/_common/utils.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_common/utils.py @@ -3,7 +3,6 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. 
# ------------------------------------------------------------------------- - import sys import datetime import logging @@ -12,8 +11,6 @@ from typing import ( Any, Dict, - Iterable, - Iterator, List, Mapping, Optional, @@ -21,9 +18,9 @@ TYPE_CHECKING, Union, Tuple, - cast + cast, + Callable ) -from contextlib import contextmanager from datetime import timezone try: @@ -31,11 +28,6 @@ except ImportError: from urllib.parse import urlparse -from uamqp import authentication, types - -from azure.core.settings import settings -from azure.core.tracing import SpanKind, Link - from .._version import VERSION from .constants import ( JWT_TOKEN_SCOPE, @@ -44,27 +36,23 @@ DEAD_LETTER_QUEUE_SUFFIX, TRANSFER_DEAD_LETTER_QUEUE_SUFFIX, USER_AGENT_PREFIX, - SPAN_NAME_SEND, - SPAN_NAME_MESSAGE, - TRACE_PARENT_PROPERTY, - TRACE_NAMESPACE, - TRACE_NAMESPACE_PROPERTY, - TRACE_PROPERTY_ENCODING, - TRACE_ENQUEUED_TIME_PROPERTY, - SPAN_ENQUEUED_TIME_PROPERTY, - SPAN_NAME_RECEIVE, ) from ..amqp import AmqpAnnotatedMessage if TYPE_CHECKING: - from .message import ( - ServiceBusReceivedMessage, - ServiceBusMessage, - ) - from azure.core.tracing import AbstractSpan + try: + # pylint:disable=unused-import + from uamqp import ( + types as uamqp_types + ) + from uamqp.authentication import JWTTokenAuth as uamqp_JWTTokenAuth + except ImportError: + pass + from .._pyamqp.authentication import JWTTokenAuth as pyamqp_JWTTokenAuth + from .message import ServiceBusReceivedMessage, ServiceBusMessage from azure.core.credentials import AzureSasCredential - from .receiver_mixins import ReceiverMixin from .._servicebus_session import BaseSession + from .._transport._base import AmqpTransport MessagesType = Union[ Mapping[str, Any], @@ -98,8 +86,9 @@ def build_uri(address, entity): return address -def create_properties(user_agent=None): - # type: (Optional[str]) -> Dict[types.AMQPSymbol, str] +def create_properties( + user_agent: Optional[str] = None, *, amqp_transport: "AmqpTransport" +) -> 
Union[Dict["uamqp_types.AMQPSymbol", str], Dict[str, str]]: """ Format the properties with which to instantiate the connection. This acts like a user agent over HTTP. @@ -109,23 +98,22 @@ def create_properties(user_agent=None): :rtype: dict """ - properties = {} - properties[types.AMQPSymbol("product")] = USER_AGENT_PREFIX - properties[types.AMQPSymbol("version")] = VERSION - framework = "Python/{}.{}.{}".format( - sys.version_info[0], sys.version_info[1], sys.version_info[2] - ) - properties[types.AMQPSymbol("framework")] = framework + properties: Dict[Any, str] = {} + properties[amqp_transport.PRODUCT_SYMBOL] = USER_AGENT_PREFIX + properties[amqp_transport.VERSION_SYMBOL] = VERSION + framework = f"Python/{sys.version_info[0]}.{sys.version_info[1]}.{sys.version_info[2]}" + properties[amqp_transport.FRAMEWORK_SYMBOL] = framework platform_str = platform.platform() - properties[types.AMQPSymbol("platform")] = platform_str + properties[amqp_transport.PLATFORM_SYMBOL] = platform_str - final_user_agent = "{}/{} {} ({})".format( - USER_AGENT_PREFIX, VERSION, framework, platform_str + final_user_agent = ( + f"{USER_AGENT_PREFIX}/{VERSION} {amqp_transport.TRANSPORT_IDENTIFIER} " + f"{framework} ({platform_str})" ) if user_agent: - final_user_agent = "{} {}".format(user_agent, final_user_agent) + final_user_agent = f"{user_agent} {final_user_agent}" - properties[types.AMQPSymbol("user-agent")] = final_user_agent + properties[amqp_transport.USER_AGENT_SYMBOL] = final_user_agent return properties @@ -143,8 +131,9 @@ def get_renewable_start_time(renewable): ) -def get_renewable_lock_duration(renewable): - # type: (Union[ServiceBusReceivedMessage, BaseSession]) -> datetime.timedelta +def get_renewable_lock_duration( + renewable: Union["ServiceBusReceivedMessage", "BaseSession"] +) -> datetime.timedelta: # pylint: disable=protected-access try: return max( @@ -157,7 +146,7 @@ def get_renewable_lock_duration(renewable): ) -def create_authentication(client): +def 
create_authentication(client) -> Union["uamqp_JWTTokenAuth", "pyamqp_JWTTokenAuth"]: # pylint: disable=protected-access try: # ignore mypy's warning because token_type is Optional @@ -165,33 +154,20 @@ def create_authentication(client): except AttributeError: token_type = TOKEN_TYPE_JWT if token_type == TOKEN_TYPE_SASTOKEN: - auth = authentication.JWTTokenAuth( + return client._amqp_transport.create_token_auth( client._auth_uri, + get_token=functools.partial(client._credential.get_token, client._auth_uri), + token_type=token_type, + config=client._config, + update_token=True + ) + return client._amqp_transport.create_token_auth( client._auth_uri, - functools.partial(client._credential.get_token, client._auth_uri), + get_token=functools.partial(client._credential.get_token, JWT_TOKEN_SCOPE), token_type=token_type, - timeout=client._config.auth_timeout, - http_proxy=client._config.http_proxy, - transport_type=client._config.transport_type, - custom_endpoint_hostname=client._config.custom_endpoint_hostname, - port=client._config.connection_port, - verify=client._config.connection_verify + config=client._config, + update_token=False, ) - auth.update_token() - return auth - return authentication.JWTTokenAuth( - client._auth_uri, - client._auth_uri, - functools.partial(client._credential.get_token, JWT_TOKEN_SCOPE), - token_type=token_type, - timeout=client._config.auth_timeout, - http_proxy=client._config.http_proxy, - transport_type=client._config.transport_type, - refresh_window=300, - custom_endpoint_hostname=client._config.custom_endpoint_hostname, - port=client._config.connection_port, - verify=client._config.connection_verify - ) def generate_dead_letter_entity_name( @@ -202,40 +178,52 @@ def generate_dead_letter_entity_name( if queue_name else (topic_name + "/Subscriptions/" + subscription_name) ) - entity_name = "{}{}".format( - entity_name, - TRANSFER_DEAD_LETTER_QUEUE_SUFFIX - if transfer_deadletter - else DEAD_LETTER_QUEUE_SUFFIX, + entity_name = ( + 
f"{entity_name}" + f"{TRANSFER_DEAD_LETTER_QUEUE_SUFFIX if transfer_deadletter else DEAD_LETTER_QUEUE_SUFFIX}" ) return entity_name -def _convert_to_single_service_bus_message(message, message_type): - # type: (SingleMessageType, Type[ServiceBusMessage]) -> ServiceBusMessage - # pylint: disable=protected-access +def _convert_to_single_service_bus_message( + message: "SingleMessageType", + message_type: Type["ServiceBusMessage"], + to_outgoing_amqp_message: Callable +) -> "ServiceBusMessage": try: # ServiceBusMessage/ServiceBusReceivedMessage - return message._to_outgoing_message() # type: ignore - except TypeError: - # AmqpAnnotatedMessage - return message._to_outgoing_message(message_type) # type: ignore + message = cast("ServiceBusMessage", message) + # pylint: disable=protected-access + message._message = to_outgoing_amqp_message(message.raw_amqp_message) + return message + except AttributeError: + # AmqpAnnotatedMessage or Mapping representation + pass + try: + message = cast(AmqpAnnotatedMessage, message) + amqp_message = to_outgoing_amqp_message(message) + return message_type(body=None, message=amqp_message, raw_amqp_message=message) except AttributeError: # Mapping representing pass - try: - return message_type(**cast(Mapping[str, Any], message))._to_outgoing_message() + # pylint: disable=protected-access + message = message_type(**cast(Mapping[str, Any], message)) + message._message = to_outgoing_amqp_message(message.raw_amqp_message) + return message except TypeError: raise TypeError( - "Only AmqpAnnotatedMessage, ServiceBusMessage instances or Mappings representing messages are supported. " - "Received instead: {}".format(message.__class__.__name__) + f"Only AmqpAnnotatedMessage, ServiceBusMessage instances or Mappings representing messages are supported. 
" + f"Received instead: {message.__class__.__name__}" ) -def transform_messages_if_needed(messages, message_type): - # type: (MessagesType, Type[ServiceBusMessage]) -> Union[ServiceBusMessage, List[ServiceBusMessage]] +def transform_outbound_messages( + messages: "MessagesType", + message_type: Type["ServiceBusMessage"], + to_outgoing_amqp_message: Callable +) -> Union["ServiceBusMessage", List["ServiceBusMessage"]]: """ This method serves multiple goals: 1. convert dict representations of one or more messages to @@ -250,13 +238,12 @@ def transform_messages_if_needed(messages, message_type): """ if isinstance(messages, list): return [ - _convert_to_single_service_bus_message(m, message_type) for m in messages + _convert_to_single_service_bus_message(m, message_type, to_outgoing_amqp_message) for m in messages ] - return _convert_to_single_service_bus_message(messages, message_type) + return _convert_to_single_service_bus_message(messages, message_type, to_outgoing_amqp_message) -def strip_protocol_from_uri(uri): - # type: (str) -> str +def strip_protocol_from_uri(uri: str) -> str: """Removes the protocol (e.g. 
http:// or sb://) from a URI, such as the FQDN.""" left_slash_pos = uri.find("//") if left_slash_pos != -1: @@ -264,89 +251,7 @@ def strip_protocol_from_uri(uri): return uri -@contextmanager -def send_trace_context_manager(span_name=SPAN_NAME_SEND): - span_impl_type = settings.tracing_implementation() # type: Type[AbstractSpan] - - if span_impl_type is not None: - with span_impl_type(name=span_name, kind=SpanKind.CLIENT) as child: - yield child - else: - yield None - - -@contextmanager -def receive_trace_context_manager( - receiver: "ReceiverMixin", - span_name: str = SPAN_NAME_RECEIVE, - links: Optional[List[Link]] = None -) -> Iterator[None]: - """Tracing""" - span_impl_type = settings.tracing_implementation() # type: Type[AbstractSpan] - if span_impl_type is None: - yield - else: - receive_span = span_impl_type(name=span_name, kind=SpanKind.CONSUMER, links=links) - receiver._add_span_request_attributes(receive_span) # type: ignore # pylint: disable=protected-access - - with receive_span: - yield - -def trace_message(message, parent_span=None): - # type: (ServiceBusMessage, Optional[AbstractSpan]) -> None - """Add tracing information to this message. - Will open and close a "Azure.Servicebus.message" span, and - add the "DiagnosticId" as app properties of the message. 
- """ - try: - span_impl_type = settings.tracing_implementation() # type: Type[AbstractSpan] - if span_impl_type is not None: - current_span = parent_span or span_impl_type( - span_impl_type.get_current_span() - ) - link = Link({ - 'traceparent': current_span.get_trace_parent() - }) - with current_span.span(name=SPAN_NAME_MESSAGE, kind=SpanKind.PRODUCER, links=[link]) as message_span: - message_span.add_attribute(TRACE_NAMESPACE_PROPERTY, TRACE_NAMESPACE) - # TODO: Remove intermediary message; this is standin while this var is being renamed in a concurrent PR - if not message.message.application_properties: - message.message.application_properties = dict() - message.message.application_properties.setdefault( - TRACE_PARENT_PROPERTY, - message_span.get_trace_parent().encode(TRACE_PROPERTY_ENCODING), - ) - except Exception as exp: # pylint:disable=broad-except - _log.warning("trace_message had an exception %r", exp) - - -def get_receive_links(messages): - trace_messages = ( - messages if isinstance(messages, Iterable) # pylint:disable=isinstance-second-argument-not-valid-type - else (messages,) - ) - - links = [] - try: - for message in trace_messages: # type: ignore - if message.message.application_properties: - traceparent = message.message.application_properties.get( - TRACE_PARENT_PROPERTY, "" - ).decode(TRACE_PROPERTY_ENCODING) - if traceparent: - links.append(Link({'traceparent': traceparent}, - { - SPAN_ENQUEUED_TIME_PROPERTY: message.message.annotations.get( - TRACE_ENQUEUED_TIME_PROPERTY - ) - })) - except AttributeError: - pass - return links - - -def parse_sas_credential(credential): - # type: (AzureSasCredential) -> Tuple +def parse_sas_credential(credential: "AzureSasCredential") -> Tuple: sas = credential.signature parsed_sas = sas.split('&') expiry = None diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/__init__.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/__init__.py new file mode 100644 index 
0000000000000..fc95444492660 --- /dev/null +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/__init__.py @@ -0,0 +1,21 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# ------------------------------------------------------------------------- + +__version__ = "2.0.0a1" + + +from ._connection import Connection +from ._transport import SSLTransport + +from .client import AMQPClient, ReceiveClient, SendClient + +__all__ = [ + "Connection", + "SSLTransport", + "AMQPClient", + "ReceiveClient", + "SendClient", +] diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/_connection.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/_connection.py new file mode 100644 index 0000000000000..46098abf7fbbe --- /dev/null +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/_connection.py @@ -0,0 +1,856 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. 
+# -------------------------------------------------------------------------- + +import uuid +import logging +import time +from urllib.parse import urlparse +import socket +from ssl import SSLError +from typing import Any, Dict, List, Tuple, Optional, NamedTuple, Union, cast + +from ._transport import Transport +from .sasl import SASLTransport, SASLWithWebSocket +from .session import Session +from .performatives import OpenFrame, CloseFrame +from .constants import ( + PORT, + SECURE_PORT, + WEBSOCKET_PORT, + MAX_CHANNELS, + MAX_FRAME_SIZE_BYTES, + HEADER_FRAME, + ConnectionState, + EMPTY_FRAME, + TransportType, +) + +from .error import ErrorCondition, AMQPConnectionError, AMQPError + +_LOGGER = logging.getLogger(__name__) +_CLOSING_STATES = ( + ConnectionState.OC_PIPE, + ConnectionState.CLOSE_PIPE, + ConnectionState.DISCARDING, + ConnectionState.CLOSE_SENT, + ConnectionState.END, +) + + +def get_local_timeout(now, idle_timeout, last_frame_received_time): + # type: (float, float, float) -> bool + """Check whether the local timeout has been reached since a new incoming frame was received. + + :param float now: The current time to check against. + :param float idle_timeout: The configured idle timeout in seconds. + :param float last_frame_received_time: The time at which the last incoming frame was received. + :rtype: bool + :returns: Whether to shut down the connection due to timeout. + """ + if idle_timeout and last_frame_received_time: + time_since_last_received = now - last_frame_received_time + return time_since_last_received > idle_timeout + return False + + +class Connection(object): # pylint:disable=too-many-instance-attributes + """An AMQP Connection. + + :ivar str state: The connection state. + :param str endpoint: The endpoint to connect to. Must be fully qualified with scheme and port number. + :keyword str container_id: The ID of the source container. If not set a GUID will be generated. + :keyword int max_frame_size: Proposed maximum frame size in bytes. Default value is 64kb. + :keyword int channel_max: The maximum channel number that may be used on the Connection. Default value is 65535.
+ :keyword int idle_timeout: Connection idle time-out in seconds. + :keyword list(str) outgoing_locales: Locales available for outgoing text. + :keyword list(str) incoming_locales: Desired locales for incoming text in decreasing level of preference. + :keyword list(str) offered_capabilities: The extension capabilities the sender supports. + :keyword list(str) desired_capabilities: The extension capabilities the sender may use if the receiver supports them. + :keyword dict properties: Connection properties. + :keyword bool allow_pipelined_open: Allow frames to be sent on the connection before a response Open frame + has been received. Default value is `True`. + :keyword float idle_timeout_empty_frame_send_ratio: Portion of the idle timeout time to wait before sending an + empty frame. The default portion is 50% of the idle timeout value (i.e. `0.5`). + :keyword float idle_wait_time: The time in seconds to sleep while waiting for a response from the endpoint. + Default value is `0.1`. + :keyword bool network_trace: Whether to log the network traffic. Default value is `False`. If enabled, frames + will be logged at the logging.INFO level. + :keyword str transport_type: Determines if the transport type is Amqp or AmqpOverWebSocket. + Defaults to TransportType.Amqp. It will be AmqpOverWebSocket if using http_proxy. + :keyword Dict http_proxy: HTTP proxy settings. This must be a dictionary with the following + keys: `'proxy_hostname'` (str value) and `'proxy_port'` (int value). When using these settings, + the transport_type would be AmqpOverWebSocket. + Additionally the following keys may also be present: `'username', 'password'`.
+ """ + + def __init__(self, endpoint, **kwargs): # pylint:disable=too-many-statements + # type: (str, Any) -> None + parsed_url = urlparse(endpoint) + self._hostname = parsed_url.hostname + endpoint = self._hostname + if parsed_url.port: + self._port = parsed_url.port + elif parsed_url.scheme == "amqps": + self._port = SECURE_PORT + else: + self._port = PORT + self.state = None # type: Optional[ConnectionState] + + # Custom Endpoint + custom_endpoint_address = kwargs.get("custom_endpoint_address") + custom_endpoint = None + if custom_endpoint_address: + custom_parsed_url = urlparse(custom_endpoint_address) + custom_port = custom_parsed_url.port or WEBSOCKET_PORT + custom_endpoint = f"{custom_parsed_url.hostname}:{custom_port}{custom_parsed_url.path}" + self._container_id = kwargs.pop("container_id", None) or str(uuid.uuid4()) # type: str + self._network_trace = kwargs.get("network_trace", False) + self._network_trace_params = {"amqpConnection": self._container_id, "amqpSession": None, "amqpLink": None} + + transport = kwargs.get("transport") + self._transport_type = kwargs.pop("transport_type", TransportType.Amqp) + if transport: + self._transport = transport + elif "sasl_credential" in kwargs: + sasl_transport = SASLTransport + if self._transport_type.name == "AmqpOverWebsocket" or kwargs.get("http_proxy"): + sasl_transport = SASLWithWebSocket + endpoint = parsed_url.hostname + parsed_url.path + self._transport = sasl_transport( + host=endpoint, + credential=kwargs["sasl_credential"], + custom_endpoint=custom_endpoint, + network_trace_params=self._network_trace_params, + **kwargs + ) + else: + self._transport = Transport( + parsed_url.netloc, + transport_type=self._transport_type, + network_trace_params=self._network_trace_params, + **kwargs) + self._max_frame_size = kwargs.pop("max_frame_size", MAX_FRAME_SIZE_BYTES) # type: int + self._remote_max_frame_size = None # type: Optional[int] + self._channel_max = kwargs.pop("channel_max", MAX_CHANNELS) # type: int +
self._idle_timeout = kwargs.pop("idle_timeout", None) # type: Optional[int] + self._outgoing_locales = kwargs.pop("outgoing_locales", None) # type: Optional[List[str]] + self._incoming_locales = kwargs.pop("incoming_locales", None) # type: Optional[List[str]] + self._offered_capabilities = None # type: Optional[str] + self._desired_capabilities = kwargs.pop("desired_capabilities", None) # type: Optional[str] + self._properties = kwargs.pop("properties", None) # type: Optional[Dict[str, str]] + self._remote_properties: Optional[Dict[str, str]] = None + + self._allow_pipelined_open = kwargs.pop("allow_pipelined_open", True) # type: bool + self._remote_idle_timeout = None # type: Optional[int] + self._remote_idle_timeout_send_frame = None # type: Optional[int] + self._idle_timeout_empty_frame_send_ratio = kwargs.get("idle_timeout_empty_frame_send_ratio", 0.5) + self._last_frame_received_time = None # type: Optional[float] + self._last_frame_sent_time = None # type: Optional[float] + self._idle_wait_time = kwargs.get("idle_wait_time", 0.1) # type: float + self._error = None + self._outgoing_endpoints = {} # type: Dict[int, Session] + self._incoming_endpoints = {} # type: Dict[int, Session] + + def __enter__(self): + self.open() + return self + + def __exit__(self, *args): + self.close() + + def _set_state(self, new_state): + # type: (ConnectionState) -> None + """Update the connection state.""" + if new_state is None: + return + previous_state = self.state + self.state = new_state + _LOGGER.info( + "Connection state changed: %r -> %r", + previous_state, + new_state, + extra=self._network_trace_params + ) + for session in self._outgoing_endpoints.values(): + session._on_connection_state_change() # pylint:disable=protected-access + + def _connect(self): + # type: () -> None + """Initiate the connection. + + If `allow_pipelined_open` is enabled, the incoming response header will be processed immediately + and the state on exiting will be HDR_EXCH. 
Otherwise, the function will return before waiting for + the response header and the final state will be HDR_SENT. + + :raises ValueError: If a reciprocating protocol header is not received during negotiation. + """ + try: + if not self.state: + self._transport.connect() + self._set_state(ConnectionState.START) + self._transport.negotiate() + self._outgoing_header() + self._set_state(ConnectionState.HDR_SENT) + if not self._allow_pipelined_open: + # TODO: List/tuple expected as variable args + self._read_frame(wait=True) + if self.state != ConnectionState.HDR_EXCH: + self._disconnect() + raise ValueError("Did not receive reciprocal protocol header. Disconnecting.") + else: + self._set_state(ConnectionState.HDR_SENT) + except (OSError, IOError, SSLError, socket.error) as exc: + # FileNotFoundError is being raised for exception parity with uamqp when invalid + # `connection_verify` file path is passed in. Remove later when resolving issue #27128. + if isinstance(exc, FileNotFoundError) and exc.filename and "ca_certs" in exc.filename: + raise + raise AMQPConnectionError( + ErrorCondition.SocketError, + description="Failed to initiate the connection due to exception: " + str(exc), + error=exc, + ) + + def _disconnect(self): + # type: () -> None + """Disconnect the transport and set state to END.""" + if self.state == ConnectionState.END: + return + self._set_state(ConnectionState.END) + self._transport.close() + + def _can_read(self): + # type: () -> bool + """Whether the connection is in a state where it is legal to read for incoming frames.""" + return self.state not in (ConnectionState.CLOSE_RCVD, ConnectionState.END) + + def _read_frame( + self, wait: Union[bool, float] = True, **kwargs: Any + ) -> bool: + """Read an incoming frame from the transport. + + :param Union[bool, float] wait: Whether to block on the socket while waiting for an incoming frame. + The default value is `True`. If set to `False`, the socket will only block for the configured timeout (0.1 seconds).
+ If set to `True`, socket will block indefinitely. If set to a timeout value in seconds, the socket will + block for at most that value. + :rtype: bool + :returns: Whether a frame was processed. + """ + if wait is False: + new_frame = self._transport.receive_frame(**kwargs) + elif wait is True: + with self._transport.block(): + new_frame = self._transport.receive_frame(**kwargs) + else: + with self._transport.block_with_timeout(timeout=wait): + new_frame = self._transport.receive_frame(**kwargs) + return self._process_incoming_frame(*new_frame) + + def _can_write(self): + # type: () -> bool + """Whether the connection is in a state where it is legal to write outgoing frames.""" + return self.state not in _CLOSING_STATES + + def _send_frame(self, channel, frame, timeout=None, **kwargs): + # type: (int, NamedTuple, Optional[int], Any) -> None + """Send a frame over the connection. + + :param int channel: The outgoing channel number. + :param NamedTuple frame: The outgoing frame. + :param int timeout: An optional timeout value to wait until the socket is ready to send the frame.
+ :rtype: None + """ + try: + raise self._error + except TypeError: + pass + + if self._can_write(): + try: + self._last_frame_sent_time = time.time() + if timeout: + with self._transport.block_with_timeout(timeout): + self._transport.send_frame(channel, frame, **kwargs) + else: + self._transport.send_frame(channel, frame, **kwargs) + except (OSError, IOError, SSLError, socket.error) as exc: + self._error = AMQPConnectionError( + ErrorCondition.SocketError, + description="Can not send frame out due to exception: " + str(exc), + error=exc, + ) + except Exception: # pylint:disable=try-except-raise + raise + else: + _LOGGER.info("Cannot write frame in current state: %r", self.state, extra=self._network_trace_params) + + def _get_next_outgoing_channel(self): + # type: () -> int + """Get the next available outgoing channel number within the max channel limit. + + :raises ValueError: If maximum channels has been reached. + :returns: The next available outgoing channel number. + :rtype: int + """ + if (len(self._incoming_endpoints) + len(self._outgoing_endpoints)) >= self._channel_max: + raise ValueError("Maximum number of channels ({}) has been reached.".format(self._channel_max)) + next_channel = next(i for i in range(1, self._channel_max) if i not in self._outgoing_endpoints) + return next_channel + + def _outgoing_empty(self): + # type: () -> None + """Send an empty frame to prevent the connection from reaching an idle timeout.""" + if self._network_trace: + _LOGGER.debug("-> EmptyFrame()", extra=self._network_trace_params) + try: + raise self._error + except TypeError: + pass + try: + if self._can_write(): + self._transport.write(EMPTY_FRAME) + self._last_frame_sent_time = time.time() + except (OSError, IOError, SSLError, socket.error) as exc: + self._error = AMQPConnectionError( + ErrorCondition.SocketError, + description="Can not send empty frame due to exception: " + str(exc), + error=exc, + ) + except Exception: # pylint:disable=try-except-raise + raise + + def 
_outgoing_header(self): + # type: () -> None + """Send the AMQP protocol header to initiate the connection.""" + self._last_frame_sent_time = time.time() + if self._network_trace: + _LOGGER.debug("-> Header(%r)", HEADER_FRAME, extra=self._network_trace_params) + self._transport.write(HEADER_FRAME) + + def _incoming_header(self, _, frame): + # type: (int, bytes) -> None + """Process an incoming AMQP protocol header and update the connection state.""" + if self._network_trace: + _LOGGER.debug("<- Header(%r)", frame, extra=self._network_trace_params) + if self.state == ConnectionState.START: + self._set_state(ConnectionState.HDR_RCVD) + elif self.state == ConnectionState.HDR_SENT: + self._set_state(ConnectionState.HDR_EXCH) + elif self.state == ConnectionState.OPEN_PIPE: + self._set_state(ConnectionState.OPEN_SENT) + + def _outgoing_open(self): + # type: () -> None + """Send an Open frame to negotiate the AMQP connection functionality.""" + open_frame = OpenFrame( + container_id=self._container_id, + hostname=self._hostname, + max_frame_size=self._max_frame_size, + channel_max=self._channel_max, + idle_timeout=self._idle_timeout * 1000 if self._idle_timeout else None, # Convert to milliseconds + outgoing_locales=self._outgoing_locales, + incoming_locales=self._incoming_locales, + offered_capabilities=self._offered_capabilities if self.state == ConnectionState.OPEN_RCVD else None, + desired_capabilities=self._desired_capabilities if self.state == ConnectionState.HDR_EXCH else None, + properties=self._properties, + ) + if self._network_trace: + _LOGGER.debug("-> %r", open_frame, extra=self._network_trace_params) + self._send_frame(0, open_frame) + + def _incoming_open(self, channel, frame): + # type: (int, Tuple[Any, ...]) -> None + """Process incoming Open frame to finish the connection negotiation. 
+ + The incoming frame format is:: + + - frame[0]: container_id (str) + - frame[1]: hostname (str) + - frame[2]: max_frame_size (int) + - frame[3]: channel_max (int) + - frame[4]: idle_timeout (Optional[int]) + - frame[5]: outgoing_locales (Optional[List[bytes]]) + - frame[6]: incoming_locales (Optional[List[bytes]]) + - frame[7]: offered_capabilities (Optional[List[bytes]]) + - frame[8]: desired_capabilities (Optional[List[bytes]]) + - frame[9]: properties (Optional[Dict[bytes, bytes]]) + + :param int channel: The incoming channel number. + :param frame: The incoming Open frame. + :type frame: Tuple[Any, ...] + :rtype: None + """ + # TODO: Add type hints for full frame tuple contents. + if self._network_trace: + _LOGGER.debug("<- %r", OpenFrame(*frame), extra=self._network_trace_params) + if channel != 0: + _LOGGER.error("OPEN frame received on a channel that is not 0.", extra=self._network_trace_params) + self.close( + error=AMQPError( + condition=ErrorCondition.NotAllowed, description="OPEN frame received on a channel that is not 0." + ) + ) + self._set_state(ConnectionState.END) + if self.state == ConnectionState.OPENED: + _LOGGER.error("OPEN frame received in the OPENED state.", extra=self._network_trace_params) + self.close() + if frame[4]: + self._remote_idle_timeout = frame[4] / 1000 # Convert to seconds + self._remote_idle_timeout_send_frame = ( + self._idle_timeout_empty_frame_send_ratio * self._remote_idle_timeout + ) + + if frame[2] < 512: + # Max frame size is less than supported minimum. + # If any of the values in the received open frame are invalid then the connection shall be closed. + # The error amqp:invalid-field shall be set in the error.condition field of the CLOSE frame. 
+ self.close(
+ error=AMQPError(
+ condition=ErrorCondition.InvalidField,
+ description="Failed parsing OPEN frame: Max frame size is less than supported minimum.",
+ )
+ )
+ _LOGGER.error(
+ "Failed parsing OPEN frame: Max frame size is less than supported minimum.",
+ extra=self._network_trace_params
+ )
+ return
+ self._remote_max_frame_size = frame[2]
+ self._remote_properties = frame[9]
+ if self.state == ConnectionState.OPEN_SENT:
+ self._set_state(ConnectionState.OPENED)
+ elif self.state == ConnectionState.HDR_EXCH:
+ self._set_state(ConnectionState.OPEN_RCVD)
+ self._outgoing_open()
+ self._set_state(ConnectionState.OPENED)
+ else:
+ self.close(
+ error=AMQPError(
+ condition=ErrorCondition.IllegalState,
+ description=f"Connection is in an illegal state: {self.state}",
+ )
+ )
+ _LOGGER.error("Connection is in an illegal state: %r", self.state, extra=self._network_trace_params)
+
+ def _outgoing_close(self, error=None):
+ # type: (Optional[AMQPError]) -> None
+ """Send a Close frame to shut down the connection, with optional error information."""
+ close_frame = CloseFrame(error=error)
+ if self._network_trace:
+ _LOGGER.debug("-> %r", close_frame, extra=self._network_trace_params)
+ self._send_frame(0, close_frame)
+
+ def _incoming_close(self, channel, frame):
+ # type: (int, Tuple[Any, ...]) -> None
+ """Process incoming Close frame to close the connection.
+
+ The incoming frame format is::
+
+ - frame[0]: error (Optional[AMQPError])
+
+ :param int channel: The incoming channel number.
+ :param frame: The incoming Close frame.
+ :type frame: Tuple[Any, ...]
+ :rtype: None
+ """
+ if self._network_trace:
+ _LOGGER.debug("<- %r", CloseFrame(*frame), extra=self._network_trace_params)
+ disconnect_states = [
+ ConnectionState.HDR_RCVD,
+ ConnectionState.HDR_EXCH,
+ ConnectionState.OPEN_RCVD,
+ ConnectionState.CLOSE_SENT,
+ ConnectionState.DISCARDING,
+ ]
+ if self.state in disconnect_states:
+ self._disconnect()
+ return
+
+ close_error = None
+ if channel > self._channel_max:
+ _LOGGER.error(
+ "CLOSE frame received on a channel greater than the supported max.",
+ extra=self._network_trace_params
+ )
+ close_error = AMQPError(condition=ErrorCondition.InvalidField, description="Invalid channel", info=None)
+
+ self._set_state(ConnectionState.CLOSE_RCVD)
+ self._outgoing_close(error=close_error)
+ self._disconnect()
+
+ if frame[0]:
+ self._error = AMQPConnectionError(
+ condition=frame[0][0], description=frame[0][1], info=frame[0][2]
+ )
+ _LOGGER.error(
+ "Connection closed with error: %r", frame[0],
+ extra=self._network_trace_params
+ )
+
+
+ def _incoming_begin(self, channel, frame):
+ # type: (int, Tuple[Any, ...]) -> None
+ """Process incoming Begin frame to finish negotiating a new session.
+
+ The incoming frame format is::
+
+ - frame[0]: remote_channel (int)
+ - frame[1]: next_outgoing_id (int)
+ - frame[2]: incoming_window (int)
+ - frame[3]: outgoing_window (int)
+ - frame[4]: handle_max (int)
+ - frame[5]: offered_capabilities (Optional[List[bytes]])
+ - frame[6]: desired_capabilities (Optional[List[bytes]])
+ - frame[7]: properties (Optional[Dict[bytes, bytes]])
+
+ :param int channel: The incoming channel number.
+ :param frame: The incoming Begin frame.
+ :type frame: Tuple[Any, ...]
+ :rtype: None + """ + try: + existing_session = self._outgoing_endpoints[frame[0]] + self._incoming_endpoints[channel] = existing_session + self._incoming_endpoints[channel]._incoming_begin( # pylint:disable=protected-access + frame + ) + except KeyError: + new_session = Session.from_incoming_frame(self, channel) + self._incoming_endpoints[channel] = new_session + new_session._incoming_begin(frame) # pylint:disable=protected-access + + def _incoming_end(self, channel, frame): + # type: (int, Tuple[Any, ...]) -> None + """Process incoming End frame to close a session. + + The incoming frame format is:: + + - frame[0]: error (Optional[AMQPError]) + + :param int channel: The incoming channel number. + :param frame: The incoming End frame. + :type frame: Tuple[Any, ...] + :rtype: None + """ + try: + self._incoming_endpoints[channel]._incoming_end(frame) # pylint:disable=protected-access + self._incoming_endpoints.pop(channel) + self._outgoing_endpoints.pop(channel) + except KeyError: + #close the connection + self.close( + error=AMQPError( + condition=ErrorCondition.ConnectionCloseForced, + description="Invalid channel number received" + )) + _LOGGER.error( + "END frame received on invalid channel. Closing connection.", + extra=self._network_trace_params + ) + return + + def _process_incoming_frame(self, channel, frame): # pylint:disable=too-many-return-statements + # type: (int, Optional[Union[bytes, Tuple[int, Tuple[Any, ...]]]]) -> bool + """Process an incoming frame, either directly or by passing to the necessary Session. + + :param int channel: The channel the frame arrived on. + :param frame: A tuple containing the performative descriptor and the field values of the frame. + This parameter can be None in the case of an empty frame or a socket timeout. + :type frame: Optional[Tuple[int, NamedTuple]] + :rtype: bool + :returns: A boolean to indicate whether more frames in a batch can be processed or whether the + incoming frame has altered the state. 
If `True` is returned, the state has changed and the batch + should be interrupted. + """ + try: + performative, fields = cast(Union[bytes, Tuple], frame) + except TypeError: + return True # Empty Frame or socket timeout + fields = cast(Tuple[Any, ...], fields) + try: + self._last_frame_received_time = time.time() + if performative == 20: + self._incoming_endpoints[channel]._incoming_transfer( # pylint:disable=protected-access + fields + ) + return False + if performative == 21: + self._incoming_endpoints[channel]._incoming_disposition( # pylint:disable=protected-access + fields + ) + return False + if performative == 19: + self._incoming_endpoints[channel]._incoming_flow( # pylint:disable=protected-access + fields + ) + return False + if performative == 18: + self._incoming_endpoints[channel]._incoming_attach( # pylint:disable=protected-access + fields + ) + return False + if performative == 22: + self._incoming_endpoints[channel]._incoming_detach( # pylint:disable=protected-access + fields + ) + return True + if performative == 17: + self._incoming_begin(channel, fields) + return True + if performative == 23: + self._incoming_end(channel, fields) + return True + if performative == 16: + self._incoming_open(channel, fields) + return True + if performative == 24: + self._incoming_close(channel, fields) + return True + if performative == 0: + self._incoming_header(channel, cast(bytes, fields)) + return True + if performative == 1: + return False + _LOGGER.error("Unrecognized incoming frame: %r", frame, extra=self._network_trace_params) + return True + except KeyError: + return True # TODO: channel error + + def _process_outgoing_frame(self, channel, frame): + # type: (int, NamedTuple) -> None + """Send an outgoing frame if the connection is in a legal state. + + :raises ValueError: If the connection is not open or not in a valid state. 
+ """ + if not self._allow_pipelined_open and self.state in [ + ConnectionState.OPEN_PIPE, + ConnectionState.OPEN_SENT, + ]: + raise ValueError("Connection not configured to allow pipeline send.") + if self.state not in [ + ConnectionState.OPEN_PIPE, + ConnectionState.OPEN_SENT, + ConnectionState.OPENED, + ]: + raise ValueError("Connection not open.") + now = time.time() + if get_local_timeout( + now, + cast(float, self._idle_timeout), + cast(float, self._last_frame_received_time), + ) or self._get_remote_timeout(now): + _LOGGER.info( + "No frame received for the idle timeout. Closing connection.", + extra=self._network_trace_params + ) + self.close( + error=AMQPError( + condition=ErrorCondition.ConnectionCloseForced, + description="No frame received for the idle timeout.", + ), + wait=False, + ) + return + self._send_frame(channel, frame) + + def _get_remote_timeout(self, now): + # type: (float) -> bool + """Check whether the local connection has reached the remote endpoints idle timeout since + the last outgoing frame was sent. + + If the time since the last since frame is greater than the allowed idle interval, an Empty + frame will be sent to maintain the connection. + + :param float now: The current time to check against. + :rtype: bool + :returns: Whether the local connection should be shutdown due to timeout. + """ + if self._remote_idle_timeout and self._last_frame_sent_time: + time_since_last_sent = now - self._last_frame_sent_time + if time_since_last_sent > cast(int, self._remote_idle_timeout_send_frame): + self._outgoing_empty() + return False + + def _wait_for_response(self, wait, end_state): + # type: (Union[bool, float], ConnectionState) -> None + """Wait for an incoming frame to be processed that will result in a desired state change. + + :param wait: Whether to wait for an incoming frame to be processed. Can be set to `True` to wait + indefinitely, or an int to wait for a specified amount of time (in seconds). To not wait, set to `False`. 
+ :type wait: bool or float + :param ConnectionState end_state: The desired end state to wait until. + :rtype: None + """ + if wait is True: + self.listen(wait=False) + while self.state != end_state: + time.sleep(self._idle_wait_time) + self.listen(wait=False) + elif wait: + self.listen(wait=False) + timeout = time.time() + wait + while self.state != end_state: + if time.time() >= timeout: + break + time.sleep(self._idle_wait_time) + self.listen(wait=False) + + def listen(self, wait=False, batch=1, **kwargs): + # type: (Union[float, int, bool], int, Any) -> None + """Listen on the socket for incoming frames and process them. + + :param wait: Whether to block on the socket until a frame arrives. If set to `True`, socket will + block indefinitely. Alternatively, if set to a time in seconds, the socket will block for at most + the specified timeout. Default value is `False`, where the socket will block for its configured read + timeout (by default 0.1 seconds). + :type wait: int or float or bool + :param int batch: The number of frames to attempt to read and process before returning. The default value + is 1, i.e. process frames one-at-a-time. A higher value should only be used when a receiver is established + and is processing incoming Transfer frames. + :rtype: None + """ + try: + raise self._error + except TypeError: + pass + try: + if self.state not in _CLOSING_STATES: + now = time.time() + if get_local_timeout( + now, + cast(float, self._idle_timeout), + cast(float, self._last_frame_received_time), + ) or self._get_remote_timeout( + now + ): + _LOGGER.info( + "No frame received for the idle timeout. 
Closing connection.",
+ extra=self._network_trace_params
+ )
+ self.close(
+ error=AMQPError(
+ condition=ErrorCondition.ConnectionCloseForced,
+ description="No frame received for the idle timeout.",
+ ),
+ wait=False,
+ )
+ return
+ if self.state == ConnectionState.END:
+ self._error = AMQPConnectionError(
+ condition=ErrorCondition.ConnectionCloseForced, description="Connection was already closed."
+ )
+ return
+ for _ in range(batch):
+ if self._can_read():
+ if self._read_frame(wait=wait, **kwargs):
+ break
+ else:
+ _LOGGER.info(
+ "Connection cannot read frames in this state: %r",
+ self.state,
+ extra=self._network_trace_params
+ )
+ break
+ except (OSError, IOError, SSLError, socket.error) as exc:
+ self._error = AMQPConnectionError(
+ ErrorCondition.SocketError,
+ description="Cannot read frame due to exception: " + str(exc),
+ error=exc,
+ )
+ except Exception: # pylint:disable=try-except-raise
+ raise
+
+ def create_session(self, **kwargs):
+ # type: (Any) -> Session
+ """Create a new session within this connection.
+
+ :keyword str name: The name of the session. If not set, a GUID will be generated.
+ :keyword int next_outgoing_id: The transfer-id of the first transfer the sender will send.
+ Default value is 0.
+ :keyword int incoming_window: The initial incoming-window of the Session. Default value is 1.
+ :keyword int outgoing_window: The initial outgoing-window of the Session. Default value is 1.
+ :keyword int handle_max: The maximum handle value that may be used on the session. Default value is 4294967295.
+ :keyword list(str) offered_capabilities: The extension capabilities the session supports.
+ :keyword list(str) desired_capabilities: The extension capabilities the session may use if
+ the endpoint supports it.
+ :keyword dict properties: Session properties.
+ :keyword bool allow_pipelined_open: Allow frames to be sent on the connection before a response Open frame
+ has been received.
Default value is that configured for the connection.
+ :keyword float idle_wait_time: The time in seconds to sleep while waiting for a response from the endpoint.
+ Default value is that configured for the connection.
+ :keyword bool network_trace: Whether to log the network traffic of this session. If enabled, frames
+ will be logged at the logging.INFO level. Default value is that configured for the connection.
+ """
+ assigned_channel = self._get_next_outgoing_channel()
+ kwargs["allow_pipelined_open"] = self._allow_pipelined_open
+ kwargs["idle_wait_time"] = self._idle_wait_time
+ session = Session(
+ self,
+ assigned_channel,
+ network_trace=kwargs.pop("network_trace", self._network_trace),
+ network_trace_params=dict(self._network_trace_params),
+ **kwargs,
+ )
+ self._outgoing_endpoints[assigned_channel] = session
+ return session
+
+ def open(self, wait=False):
+ # type: (bool) -> None
+ """Send an Open frame to start the connection.
+
+ Alternatively, this will be called on entering a Connection context manager.
+
+ :param bool wait: Whether to wait to receive an Open response from the endpoint. Default is `False`.
+ :raises ValueError: If `wait` is set to `False` and `allow_pipelined_open` is disabled.
+ :rtype: None
+ """
+ self._connect()
+ self._outgoing_open()
+ if self.state == ConnectionState.HDR_EXCH:
+ self._set_state(ConnectionState.OPEN_SENT)
+ elif self.state == ConnectionState.HDR_SENT:
+ self._set_state(ConnectionState.OPEN_PIPE)
+ if wait:
+ self._wait_for_response(wait, ConnectionState.OPENED)
+ elif not self._allow_pipelined_open:
+ raise ValueError(
+ "Connection has been configured to not allow pipelined-open. Please set the 'wait' parameter."
+ )
+
+
+ def close(self, error=None, wait=False):
+ # type: (Optional[AMQPError], bool) -> None
+ """Close the connection and disconnect the transport.
+
+ Alternatively, this method will be called on exiting a Connection context manager.
+ + :param ~uamqp.AMQPError error: Optional error information to include in the close request. + :param bool wait: Whether to wait for a service Close response. Default is `False`. + :rtype: None + """ + try: + if self.state in [ + ConnectionState.END, + ConnectionState.CLOSE_SENT, + ConnectionState.DISCARDING, + ]: + return + self._outgoing_close(error=error) + if error: + self._error = AMQPConnectionError( + condition=error.condition, + description=error.description, + info=error.info, + ) + if self.state == ConnectionState.OPEN_PIPE: + self._set_state(ConnectionState.OC_PIPE) + elif self.state == ConnectionState.OPEN_SENT: + self._set_state(ConnectionState.CLOSE_PIPE) + elif error: + self._set_state(ConnectionState.DISCARDING) + else: + self._set_state(ConnectionState.CLOSE_SENT) + self._wait_for_response(wait, ConnectionState.END) + except Exception as exc: # pylint:disable=broad-except + # If error happened during closing, ignore the error and set state to END + _LOGGER.info("An error occurred when closing the connection: %r", exc, extra=self._network_trace_params) + self._set_state(ConnectionState.END) + finally: + self._disconnect() diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/_decode.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/_decode.py new file mode 100644 index 0000000000000..0990697128652 --- /dev/null +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/_decode.py @@ -0,0 +1,349 @@ +#------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. 
+#-------------------------------------------------------------------------- +# pylint: disable=redefined-builtin, import-error + +import struct +import uuid +import logging +from typing import List, Optional, Tuple, Dict, Callable, Any, cast, Union # pylint: disable=unused-import + + +from .message import Message, Header, Properties + +_LOGGER = logging.getLogger(__name__) +_HEADER_PREFIX = memoryview(b'AMQP') +_COMPOSITES = { + 35: 'received', + 36: 'accepted', + 37: 'rejected', + 38: 'released', + 39: 'modified', +} + +c_unsigned_char = struct.Struct('>B') +c_signed_char = struct.Struct('>b') +c_unsigned_short = struct.Struct('>H') +c_signed_short = struct.Struct('>h') +c_unsigned_int = struct.Struct('>I') +c_signed_int = struct.Struct('>i') +c_unsigned_long = struct.Struct('>L') +c_unsigned_long_long = struct.Struct('>Q') +c_signed_long_long = struct.Struct('>q') +c_float = struct.Struct('>f') +c_double = struct.Struct('>d') + + +def _decode_null(buffer): + # type: (memoryview) -> Tuple[memoryview, None] + return buffer, None + + +def _decode_true(buffer): + # type: (memoryview) -> Tuple[memoryview, bool] + return buffer, True + + +def _decode_false(buffer): + # type: (memoryview) -> Tuple[memoryview, bool] + return buffer, False + + +def _decode_zero(buffer): + # type: (memoryview) -> Tuple[memoryview, int] + return buffer, 0 + + +def _decode_empty(buffer): + # type: (memoryview) -> Tuple[memoryview, List[None]] + return buffer, [] + + +def _decode_boolean(buffer): + # type: (memoryview) -> Tuple[memoryview, bool] + return buffer[1:], buffer[:1] == b'\x01' + + +def _decode_ubyte(buffer): + # type: (memoryview) -> Tuple[memoryview, int] + return buffer[1:], buffer[0] + + +def _decode_ushort(buffer): + # type: (memoryview) -> Tuple[memoryview, int] + return buffer[2:], c_unsigned_short.unpack(buffer[:2])[0] + + +def _decode_uint_small(buffer): + # type: (memoryview) -> Tuple[memoryview, int] + return buffer[1:], buffer[0] + + +def _decode_uint_large(buffer): + # 
type: (memoryview) -> Tuple[memoryview, int] + return buffer[4:], c_unsigned_int.unpack(buffer[:4])[0] + + +def _decode_ulong_small(buffer): + # type: (memoryview) -> Tuple[memoryview, int] + return buffer[1:], buffer[0] + + +def _decode_ulong_large(buffer): + # type: (memoryview) -> Tuple[memoryview, int] + return buffer[8:], c_unsigned_long_long.unpack(buffer[:8])[0] + + +def _decode_byte(buffer): + # type: (memoryview) -> Tuple[memoryview, int] + return buffer[1:], c_signed_char.unpack(buffer[:1])[0] + + +def _decode_short(buffer): + # type: (memoryview) -> Tuple[memoryview, int] + return buffer[2:], c_signed_short.unpack(buffer[:2])[0] + + +def _decode_int_small(buffer): + # type: (memoryview) -> Tuple[memoryview, int] + return buffer[1:], c_signed_char.unpack(buffer[:1])[0] + + +def _decode_int_large(buffer): + # type: (memoryview) -> Tuple[memoryview, int] + return buffer[4:], c_signed_int.unpack(buffer[:4])[0] + + +def _decode_long_small(buffer): + # type: (memoryview) -> Tuple[memoryview, int] + return buffer[1:], c_signed_char.unpack(buffer[:1])[0] + + +def _decode_long_large(buffer): + # type: (memoryview) -> Tuple[memoryview, int] + return buffer[8:], c_signed_long_long.unpack(buffer[:8])[0] + + +def _decode_float(buffer): + # type: (memoryview) -> Tuple[memoryview, float] + return buffer[4:], c_float.unpack(buffer[:4])[0] + + +def _decode_double(buffer): + # type: (memoryview) -> Tuple[memoryview, float] + return buffer[8:], c_double.unpack(buffer[:8])[0] + + +def _decode_timestamp(buffer): + # type: (memoryview) -> Tuple[memoryview, int] + return buffer[8:], c_signed_long_long.unpack(buffer[:8])[0] + + +def _decode_uuid(buffer): + # type: (memoryview) -> Tuple[memoryview, uuid.UUID] + return buffer[16:], uuid.UUID(bytes=buffer[:16].tobytes()) + + +def _decode_binary_small(buffer): + # type: (memoryview) -> Tuple[memoryview, bytes] + length_index = buffer[0] + 1 + return buffer[length_index:], buffer[1:length_index].tobytes() + + +def 
_decode_binary_large(buffer): + # type: (memoryview) -> Tuple[memoryview, bytes] + length_index = c_unsigned_long.unpack(buffer[:4])[0] + 4 + return buffer[length_index:], buffer[4:length_index].tobytes() + + +def _decode_list_small(buffer): + # type: (memoryview) -> Tuple[memoryview, List[Any]] + count = buffer[1] + buffer = buffer[2:] + values = [None] * count + for i in range(count): + buffer, values[i] = _DECODE_BY_CONSTRUCTOR[buffer[0]](buffer[1:]) + return buffer, values + + +def _decode_list_large(buffer): + # type: (memoryview) -> Tuple[memoryview, List[Any]] + count = c_unsigned_long.unpack(buffer[4:8])[0] + buffer = buffer[8:] + values = [None] * count + for i in range(count): + buffer, values[i] = _DECODE_BY_CONSTRUCTOR[buffer[0]](buffer[1:]) + return buffer, values + + +def _decode_map_small(buffer): + # type: (memoryview) -> Tuple[memoryview, Dict[Any, Any]] + count = int(buffer[1]/2) + buffer = buffer[2:] + values = {} + for _ in range(count): + buffer, key = _DECODE_BY_CONSTRUCTOR[buffer[0]](buffer[1:]) + buffer, value = _DECODE_BY_CONSTRUCTOR[buffer[0]](buffer[1:]) + values[key] = value + return buffer, values + + +def _decode_map_large(buffer): + # type: (memoryview) -> Tuple[memoryview, Dict[Any, Any]] + count = int(c_unsigned_long.unpack(buffer[4:8])[0]/2) + buffer = buffer[8:] + values = {} + for _ in range(count): + buffer, key = _DECODE_BY_CONSTRUCTOR[buffer[0]](buffer[1:]) + buffer, value = _DECODE_BY_CONSTRUCTOR[buffer[0]](buffer[1:]) + values[key] = value + return buffer, values + + +def _decode_array_small(buffer): + # type: (memoryview) -> Tuple[memoryview, List[Any]] + count = buffer[1] # Ignore first byte (size) and just rely on count + if count: + subconstructor = buffer[2] + buffer = buffer[3:] + values = [None] * count + for i in range(count): + buffer, values[i] = _DECODE_BY_CONSTRUCTOR[subconstructor](buffer) + return buffer, values + return buffer[2:], [] + + +def _decode_array_large(buffer): + # type: (memoryview) -> 
Tuple[memoryview, List[Any]] + count = c_unsigned_long.unpack(buffer[4:8])[0] + if count: + subconstructor = buffer[8] + buffer = buffer[9:] + values = [None] * count + for i in range(count): + buffer, values[i] = _DECODE_BY_CONSTRUCTOR[subconstructor](buffer) + return buffer, values + return buffer[8:], [] + + +def _decode_described(buffer): + # type: (memoryview) -> Tuple[memoryview, Any] + # TODO: to move the cursor of the buffer to the described value based on size of the + # descriptor without decoding descriptor value + composite_type = buffer[0] + buffer, descriptor = _DECODE_BY_CONSTRUCTOR[composite_type](buffer[1:]) + buffer, value = _DECODE_BY_CONSTRUCTOR[buffer[0]](buffer[1:]) + try: + composite_type = cast(int, _COMPOSITES[descriptor]) + return buffer, {composite_type: value} + except KeyError: + return buffer, value + + +def decode_payload(buffer): + # type: (memoryview) -> Message + message: Dict[str, Union[Properties, Header, Dict, bytes, List]] = {} + while buffer: + # Ignore the first two bytes, they will always be the constructors for + # described type then ulong. 
+ descriptor = buffer[2] + buffer, value = _DECODE_BY_CONSTRUCTOR[buffer[3]](buffer[4:]) + if descriptor == 112: + message["header"] = Header(*value) + elif descriptor == 113: + message["delivery_annotations"] = value + elif descriptor == 114: + message["message_annotations"] = value + elif descriptor == 115: + message["properties"] = Properties(*value) + elif descriptor == 116: + message["application_properties"] = value + elif descriptor == 117: + try: + cast(List, message["data"]).append(value) + except KeyError: + message["data"] = [value] + elif descriptor == 118: + try: + cast(List, message["sequence"]).append(value) + except KeyError: + message["sequence"] = [value] + elif descriptor == 119: + message["value"] = value + elif descriptor == 120: + message["footer"] = value + # TODO: we can possibly swap out the Message construct with a TypedDict + # for both input and output so we get the best of both. + return Message(**message) + + +def decode_frame(data): + # type: (memoryview) -> Tuple[int, List[Any]] + # Ignore the first two bytes, they will always be the constructors for + # described type then ulong. 
+ frame_type = data[2] + compound_list_type = data[3] + if compound_list_type == 0xd0: + # list32 0xd0: data[4:8] is size, data[8:12] is count + count = c_signed_int.unpack(data[8:12])[0] + buffer = data[12:] + else: + # list8 0xc0: data[4] is size, data[5] is count + count = data[5] + buffer = data[6:] + fields: List[Optional[memoryview]] = [None] * count + for i in range(count): + buffer, fields[i] = _DECODE_BY_CONSTRUCTOR[buffer[0]](buffer[1:]) + if frame_type == 20: + fields.append(buffer) + return frame_type, fields + + +def decode_empty_frame(header): + # type: (memoryview) -> Tuple[int, bytes] + if header[0:4] == _HEADER_PREFIX: + return 0, header.tobytes() + if header[5] == 0: + return 1, b"EMPTY" + raise ValueError("Received unrecognized empty frame") + + +_DECODE_BY_CONSTRUCTOR: List[Callable] = cast(List[Callable], [None] * 256) +_DECODE_BY_CONSTRUCTOR[0] = _decode_described +_DECODE_BY_CONSTRUCTOR[64] = _decode_null +_DECODE_BY_CONSTRUCTOR[65] = _decode_true +_DECODE_BY_CONSTRUCTOR[66] = _decode_false +_DECODE_BY_CONSTRUCTOR[67] = _decode_zero +_DECODE_BY_CONSTRUCTOR[68] = _decode_zero +_DECODE_BY_CONSTRUCTOR[69] = _decode_empty +_DECODE_BY_CONSTRUCTOR[80] = _decode_ubyte +_DECODE_BY_CONSTRUCTOR[81] = _decode_byte +_DECODE_BY_CONSTRUCTOR[82] = _decode_uint_small +_DECODE_BY_CONSTRUCTOR[83] = _decode_ulong_small +_DECODE_BY_CONSTRUCTOR[84] = _decode_int_small +_DECODE_BY_CONSTRUCTOR[85] = _decode_long_small +_DECODE_BY_CONSTRUCTOR[86] = _decode_boolean +_DECODE_BY_CONSTRUCTOR[96] = _decode_ushort +_DECODE_BY_CONSTRUCTOR[97] = _decode_short +_DECODE_BY_CONSTRUCTOR[112] = _decode_uint_large +_DECODE_BY_CONSTRUCTOR[113] = _decode_int_large +_DECODE_BY_CONSTRUCTOR[114] = _decode_float +_DECODE_BY_CONSTRUCTOR[128] = _decode_ulong_large +_DECODE_BY_CONSTRUCTOR[129] = _decode_long_large +_DECODE_BY_CONSTRUCTOR[130] = _decode_double +_DECODE_BY_CONSTRUCTOR[131] = _decode_timestamp +_DECODE_BY_CONSTRUCTOR[152] = _decode_uuid +_DECODE_BY_CONSTRUCTOR[160] = 
_decode_binary_small +_DECODE_BY_CONSTRUCTOR[161] = _decode_binary_small +_DECODE_BY_CONSTRUCTOR[163] = _decode_binary_small +_DECODE_BY_CONSTRUCTOR[176] = _decode_binary_large +_DECODE_BY_CONSTRUCTOR[177] = _decode_binary_large +_DECODE_BY_CONSTRUCTOR[179] = _decode_binary_large +_DECODE_BY_CONSTRUCTOR[192] = _decode_list_small +_DECODE_BY_CONSTRUCTOR[193] = _decode_map_small +_DECODE_BY_CONSTRUCTOR[208] = _decode_list_large +_DECODE_BY_CONSTRUCTOR[209] = _decode_map_large +_DECODE_BY_CONSTRUCTOR[224] = _decode_array_small +_DECODE_BY_CONSTRUCTOR[240] = _decode_array_large diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/_encode.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/_encode.py new file mode 100644 index 0000000000000..381fb26be6d86 --- /dev/null +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/_encode.py @@ -0,0 +1,920 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +# TODO: fix mypy errors for _code/_definition/__defaults__ (issue #26500) +import calendar +import struct +import uuid +from datetime import datetime +from typing import ( + Iterable, + Union, + Tuple, + Dict, + Any, + cast, + Sized, + Optional, + List, + Callable, + TYPE_CHECKING, + Sequence, + Collection, +) + +try: + from typing import TypeAlias # type: ignore +except ImportError: + from typing_extensions import TypeAlias + + +from .types import ( + TYPE, + VALUE, + AMQPTypes, + FieldDefinition, + ObjDefinition, + ConstructorBytes, +) +from .message import Message +from . 
import performatives + +if TYPE_CHECKING: + from .message import Header, Properties + + Performative: TypeAlias = Union[ + performatives.OpenFrame, + performatives.BeginFrame, + performatives.AttachFrame, + performatives.FlowFrame, + performatives.TransferFrame, + performatives.DispositionFrame, + performatives.DetachFrame, + performatives.EndFrame, + performatives.CloseFrame, + performatives.SASLMechanism, + performatives.SASLInit, + performatives.SASLChallenge, + performatives.SASLResponse, + performatives.SASLOutcome, + Message, + Header, + Properties, + ] + +_FRAME_OFFSET = b"\x02" +_FRAME_TYPE = b"\x00" + + +def _construct(byte, construct): + # type: (bytes, bool) -> bytes + return byte if construct else b"" + + +def encode_null(output, *args, **kwargs): # pylint: disable=unused-argument + # type: (bytearray, Any, Any) -> None + """ + encoding code="0x40" category="fixed" width="0" label="the null value" + """ + output.extend(ConstructorBytes.null) + + +def encode_boolean( + output, value, with_constructor=True, **kwargs # pylint: disable=unused-argument +): + # type: (bytearray, bool, bool, Any) -> None + """ + + + + """ + value = bool(value) + if with_constructor: + output.extend(_construct(ConstructorBytes.bool, with_constructor)) + output.extend(b"\x01" if value else b"\x00") + return + + output.extend(ConstructorBytes.bool_true if value else ConstructorBytes.bool_false) + + +def encode_ubyte( + output, value, with_constructor=True, **kwargs # pylint: disable=unused-argument +): + # type: (bytearray, Union[int, bytes], bool, Any) -> None + """ + + """ + try: + value = int(value) + except ValueError: + value = cast(bytes, value) + value = ord(value) + try: + output.extend(_construct(ConstructorBytes.ubyte, with_constructor)) + output.extend(struct.pack(">B", abs(value))) + except struct.error: + raise ValueError("Unsigned byte value must be 0-255") + + +def encode_ushort( + output, value, with_constructor=True, **kwargs # pylint: disable=unused-argument +): 
+ # type: (bytearray, int, bool, Any) -> None + """ + encoding code="0x60" category="fixed" width="2" label="16-bit unsigned integer in network byte order" + """ + value = int(value) + try: + output.extend(_construct(ConstructorBytes.ushort, with_constructor)) + output.extend(struct.pack(">H", abs(value))) + except struct.error: + raise ValueError("Unsigned short value must be 0-65535") + + +def encode_uint(output, value, with_constructor=True, use_smallest=True): + # type: (bytearray, int, bool, bool) -> None + """ + encoding name="uint" code="0x70" category="fixed" width="4" label="32-bit unsigned integer in network byte order" + encoding name="smalluint" code="0x52" category="fixed" width="1" label="unsigned integer value in the range 0 to 255 inclusive" + encoding name="uint0" code="0x43" category="fixed" width="0" label="the uint value 0" + """ + value = int(value) + if value == 0: + output.extend(ConstructorBytes.uint_0) + return + try: + if use_smallest and value <= 255: + output.extend(_construct(ConstructorBytes.uint_small, with_constructor)) + output.extend(struct.pack(">B", abs(value))) + return + output.extend(_construct(ConstructorBytes.uint_large, with_constructor)) + output.extend(struct.pack(">I", abs(value))) + except struct.error: + raise ValueError("Value supplied for unsigned int invalid: {}".format(value)) + + +def encode_ulong(output, value, with_constructor=True, use_smallest=True): + # type: (bytearray, int, bool, bool) -> None + """ + encoding name="ulong" code="0x80" category="fixed" width="8" label="64-bit unsigned integer in network byte order" + encoding name="smallulong" code="0x53" category="fixed" width="1" label="unsigned long value in the range 0 to 255 inclusive" + encoding name="ulong0" code="0x44" category="fixed" width="0" label="the ulong value 0" + """ + value = int(value) + if value == 0: + output.extend(ConstructorBytes.ulong_0) + return + try: + if use_smallest and value <= 255: + output.extend(_construct(ConstructorBytes.ulong_small, with_constructor)) + output.extend(struct.pack(">B", abs(value))) + return + output.extend(_construct(ConstructorBytes.ulong_large, with_constructor)) + output.extend(struct.pack(">Q", abs(value))) + except struct.error: + raise ValueError("Value supplied for unsigned long invalid: {}".format(value)) + + +def encode_byte( + output, value, with_constructor=True, **kwargs # pylint: disable=unused-argument +): + # type: (bytearray, int, bool, Any) -> None + """ + encoding code="0x51" category="fixed" width="1" label="8-bit two's-complement integer" + """ + value = int(value) + try: + output.extend(_construct(ConstructorBytes.byte, with_constructor)) + output.extend(struct.pack(">b", value)) + except struct.error: + raise ValueError("Byte value must be -128-127") + + +def encode_short( + output, value, with_constructor=True,
**kwargs # pylint: disable=unused-argument +): + # type: (bytearray, int, bool, Any) -> None + """ + encoding code="0x61" category="fixed" width="2" label="16-bit two's-complement integer in network byte order" + """ + value = int(value) + try: + output.extend(_construct(ConstructorBytes.short, with_constructor)) + output.extend(struct.pack(">h", value)) + except struct.error: + raise ValueError("Short value must be -32768-32767") + + +def encode_int(output, value, with_constructor=True, use_smallest=True): + # type: (bytearray, int, bool, bool) -> None + """ + encoding name="int" code="0x71" category="fixed" width="4" label="32-bit two's-complement integer in network byte order" + encoding name="smallint" code="0x54" category="fixed" width="1" label="8-bit two's-complement integer" + """ + value = int(value) + try: + if use_smallest and (-128 <= value <= 127): + output.extend(_construct(ConstructorBytes.int_small, with_constructor)) + output.extend(struct.pack(">b", value)) + return + output.extend(_construct(ConstructorBytes.int_large, with_constructor)) + output.extend(struct.pack(">i", value)) + except struct.error: + raise ValueError("Value supplied for int invalid: {}".format(value)) + + +def encode_long(output, value, with_constructor=True, use_smallest=True): + # type: (bytearray, int, bool, bool) -> None + """ + encoding name="long" code="0x81" category="fixed" width="8" label="64-bit two's-complement integer in network byte order" + encoding name="smalllong" code="0x55" category="fixed" width="1" label="8-bit two's-complement integer" + """ + if isinstance(value, datetime): + value = (calendar.timegm(value.utctimetuple()) * 1000) + ( + value.microsecond / 1000 + ) + value = int(value) + try: + if use_smallest and (-128 <= value <= 127): + output.extend(_construct(ConstructorBytes.long_small, with_constructor)) + output.extend(struct.pack(">b", value)) + return + output.extend(_construct(ConstructorBytes.long_large, with_constructor)) + output.extend(struct.pack(">q", value)) + except struct.error: + raise ValueError("Value supplied for long invalid: {}".format(value)) + + +def encode_float( + output, value, with_constructor=True, **kwargs # pylint: disable=unused-argument +): + # type: (bytearray, float, bool, Any) -> None + """ + encoding name="ieee-754" code="0x72" category="fixed" width="4" label="IEEE 754-2008 binary32" + """ + value = float(value) + output.extend(_construct(ConstructorBytes.float, with_constructor)) + output.extend(struct.pack(">f", value)) + + +def encode_double( + output, value, with_constructor=True, **kwargs # pylint: disable=unused-argument +): + # type: (bytearray,
float, bool, Any) -> None + """ + encoding name="ieee-754" code="0x82" category="fixed" width="8" label="IEEE 754-2008 binary64" + """ + value = float(value) + output.extend(_construct(ConstructorBytes.double, with_constructor)) + output.extend(struct.pack(">d", value)) + + +def encode_timestamp( + output, value, with_constructor=True, **kwargs # pylint: disable=unused-argument +): + # type: (bytearray, Union[int, datetime], bool, Any) -> None + """ + encoding name="ms64" code="0x83" category="fixed" width="8" label="64-bit two's-complement integer representing milliseconds since the unix epoch" + """ + value = cast(datetime, value) + if isinstance(value, datetime): + value = cast( + int, + (calendar.timegm(value.utctimetuple()) * 1000) + (value.microsecond / 1000), + ) + value = int(cast(int, value)) + output.extend(_construct(ConstructorBytes.timestamp, with_constructor)) + output.extend(struct.pack(">q", value)) + + +def encode_uuid( + output, value, with_constructor=True, **kwargs # pylint: disable=unused-argument +): + # type: (bytearray, Union[uuid.UUID, str, bytes], bool, Any) -> None + """ + encoding code="0x98" category="fixed" width="16" label="UUID as defined in section 4.1.2 of RFC-4122" + """ + if isinstance(value, str): + value = uuid.UUID(value).bytes + elif isinstance(value, uuid.UUID): + value = value.bytes + elif isinstance(value, bytes): + value = uuid.UUID(bytes=value).bytes + else: + raise TypeError("Invalid UUID type: {}".format(type(value))) + output.extend(_construct(ConstructorBytes.uuid, with_constructor)) + output.extend(value) + + +def encode_binary(output, value, with_constructor=True, use_smallest=True): + # type: (bytearray, Union[bytes, bytearray], bool, bool) -> None + """ + encoding name="vbin8" code="0xa0" category="variable" width="1" label="up to 2^8 - 1 octets of binary data" + encoding name="vbin32" code="0xb0" category="variable" width="4" label="up to 2^32 - 1 octets of binary data" + """ + length = len(value) + if use_smallest and length <= 255: + output.extend(_construct(ConstructorBytes.binary_small, with_constructor)) + output.extend(struct.pack(">B", length)) + output.extend(value) + return + try: + output.extend(_construct(ConstructorBytes.binary_large, with_constructor)) + output.extend(struct.pack(">L", length)) + output.extend(value) + except struct.error: + raise ValueError("Binary data too long to encode") + + +def encode_string(output, value, with_constructor=True, use_smallest=True): + # type: (bytearray, Union[bytes, str], bool, bool) -> None + """ + encoding name="str8-utf8" code="0xa1" category="variable" width="1" label="up to 2^8 - 1 octets worth of UTF-8 unicode" + encoding name="str32-utf8" code="0xb1" category="variable" width="4" label="up to 2^32 - 1 octets worth of UTF-8 unicode" + """ + if
isinstance(value, str): + value = value.encode("utf-8") + length = len(value) + if use_smallest and length <= 255: + output.extend(_construct(ConstructorBytes.string_small, with_constructor)) + output.extend(struct.pack(">B", length)) + output.extend(value) + return + try: + output.extend(_construct(ConstructorBytes.string_large, with_constructor)) + output.extend(struct.pack(">L", length)) + output.extend(value) + except struct.error: + raise ValueError("String value too long to encode.") + + +def encode_symbol(output, value, with_constructor=True, use_smallest=True): + # type: (bytearray, Union[bytes, str], bool, bool) -> None + """ + encoding name="sym8" code="0xa3" category="variable" width="1" label="up to 2^8 - 1 seven bit ASCII characters representing a symbolic value" + encoding name="sym32" code="0xb3" category="variable" width="4" label="up to 2^32 - 1 seven bit ASCII characters representing a symbolic value" + """ + if isinstance(value, str): + value = value.encode("utf-8") + length = len(value) + if use_smallest and length <= 255: + output.extend(_construct(ConstructorBytes.symbol_small, with_constructor)) + output.extend(struct.pack(">B", length)) + output.extend(value) + return + try: + output.extend(_construct(ConstructorBytes.symbol_large, with_constructor)) + output.extend(struct.pack(">L", length)) + output.extend(value) + except struct.error: + raise ValueError("Symbol value too long to encode.") + + +def encode_list(output, value, with_constructor=True, use_smallest=True): + # type: (bytearray, Iterable[Any], bool, bool) -> None + """ + encoding name="list0" code="0x45" category="fixed" width="0" label="the empty list (i.e. the list with no elements)" + encoding name="list8" code="0xc0" category="compound" width="1" label="up to 2^8 - 1 list elements with total size less than 2^8 octets" + encoding name="list32" code="0xd0" category="compound" width="4" label="up to 2^32 - 1 list elements with total size less than 2^32 octets" + """ + count = len(cast(Sized, value)) + if use_smallest and count == 0: + output.extend(ConstructorBytes.list_0) + return + encoded_size = 0 + encoded_values = bytearray() + for item in value: + encode_value(encoded_values, item, with_constructor=True) + encoded_size += len(encoded_values) + if use_smallest and count <= 255 and encoded_size < 255: + output.extend(_construct(ConstructorBytes.list_small, with_constructor)) + output.extend(struct.pack(">B", encoded_size + 1)) + output.extend(struct.pack(">B", count)) + else: + try: + output.extend(_construct(ConstructorBytes.list_large, with_constructor)) + output.extend(struct.pack(">L", encoded_size + 4)) + output.extend(struct.pack(">L",
count)) + except struct.error: + raise ValueError("List is too large or too long to be encoded.") + output.extend(encoded_values) + + +def encode_map(output, value, with_constructor=True, use_smallest=True): + # type: (bytearray, Union[Dict[Any, Any], Iterable[Tuple[Any, Any]]], bool, bool) -> None + """ + encoding name="map8" code="0xc1" category="compound" width="1" label="up to 2^8 - 1 octets of encoded map data" + encoding name="map32" code="0xd1" category="compound" width="4" label="up to 2^32 - 1 octets of encoded map data" + """ + count = len(cast(Sized, value)) * 2 + encoded_size = 0 + encoded_values = bytearray() + try: + value = cast(Dict, value) + items = cast(Iterable, value.items()) + except AttributeError: + items = cast(Iterable, value) + for key, data in items: + encode_value(encoded_values, key, with_constructor=True) + encode_value(encoded_values, data, with_constructor=True) + encoded_size = len(encoded_values) + if use_smallest and count <= 255 and encoded_size < 255: + output.extend(_construct(ConstructorBytes.map_small, with_constructor)) + output.extend(struct.pack(">B", encoded_size + 1)) + output.extend(struct.pack(">B", count)) + else: + try: + output.extend(_construct(ConstructorBytes.map_large, with_constructor)) + output.extend(struct.pack(">L", encoded_size + 4)) + output.extend(struct.pack(">L", count)) + except struct.error: + raise ValueError("Map is too large or too long to be encoded.") + output.extend(encoded_values) + + +def _check_element_type(item, element_type): + if not element_type: + try: + return item["TYPE"] + except (KeyError, TypeError): + return type(item) + try: + if item["TYPE"] != element_type: + raise TypeError("All elements in an array must be the same type.") + except (KeyError, TypeError): + if not isinstance(item, element_type): + raise TypeError("All elements in an array must be the same type.") + return element_type + + +def encode_array(output, value, with_constructor=True, use_smallest=True): + # type: (bytearray, Iterable[Any], bool, bool) -> None + """ + encoding name="array8" code="0xe0" category="array" width="1" label="up to 2^8 - 1 array elements with total size less than 2^8 octets" + encoding name="array32" code="0xf0" category="array" width="4" label="up to 2^32 - 1 array elements with total size less than 2^32 octets" + """ + count = len(cast(Sized, value)) + encoded_size = 0 + encoded_values = bytearray() + first_item = True + element_type = None + for item in value: + element_type =
_check_element_type(item, element_type) + encode_value( + encoded_values, item, with_constructor=first_item, use_smallest=False + ) + first_item = False + if item is None: + encoded_size -= 1 + break + encoded_size += len(encoded_values) + if use_smallest and count <= 255 and encoded_size < 255: + output.extend(_construct(ConstructorBytes.array_small, with_constructor)) + output.extend(struct.pack(">B", encoded_size + 1)) + output.extend(struct.pack(">B", count)) + else: + try: + output.extend(_construct(ConstructorBytes.array_large, with_constructor)) + output.extend(struct.pack(">L", encoded_size + 4)) + output.extend(struct.pack(">L", count)) + except struct.error: + raise ValueError("Array is too large or too long to be encoded.") + output.extend(encoded_values) + + +def encode_described(output: bytearray, value: Tuple[Any, Any], _: bool = None, **kwargs: Any) -> None: # type: ignore + output.extend(ConstructorBytes.descriptor) + encode_value(output, value[0], **kwargs) + encode_value(output, value[1], **kwargs) + + +def encode_fields(value): + # type: (Optional[Dict[str, Any]]) -> Dict[str, Any] + """A mapping from field name to value. + + The fields type is a map where the keys are restricted to be of type symbol (this excludes the possibility + of a null key). There is no further restriction implied by the fields type on the allowed values for the + entries or the set of allowed keys. + + + """ + if not value: + return {TYPE: AMQPTypes.null, VALUE: None} + fields = {TYPE: AMQPTypes.map, VALUE: []} + for key, data in value.items(): + if isinstance(key, str): + key = key.encode("utf-8") # type: ignore + cast(List, fields[VALUE]).append(({TYPE: AMQPTypes.symbol, VALUE: key}, data)) + return fields + + +def encode_annotations(value): + # type: (Optional[Dict[str, Any]]) -> Dict[str, Any] + """The annotations type is a map where the keys are restricted to be of type symbol or of type ulong. 
+ + All ulong keys, and all symbolic keys except those beginning with "x-" are reserved. + On receiving an annotations map containing keys or values which it does not recognize, and for which the + key does not begin with the string 'x-opt-' an AMQP container MUST detach the link with the not-implemented + amqp-error. + + + """ + if not value: + return {TYPE: AMQPTypes.null, VALUE: None} + fields = {TYPE: AMQPTypes.map, VALUE: []} + for key, data in value.items(): + if isinstance(key, int): + field_key = {TYPE: AMQPTypes.ulong, VALUE: key} + else: + field_key = {TYPE: AMQPTypes.symbol, VALUE: key} + try: + cast(List, fields[VALUE]).append( + (field_key, {TYPE: data[TYPE], VALUE: data[VALUE]}) + ) + except (KeyError, TypeError): + cast(List, fields[VALUE]).append((field_key, {TYPE: None, VALUE: data})) + return fields + + +def encode_application_properties(value): + # type: (Optional[Dict[str, Any]]) -> Dict[str, Any] + """The application-properties section is a part of the bare message used for structured application data. + + + + + + Intermediaries may use the data within this structure for the purposes of filtering or routing. + The keys of this map are restricted to be of type string (which excludes the possibility of a null key) + and the values are restricted to be of simple types only, that is (excluding map, list, and array types). 
+ """ + if not value: + return {TYPE: AMQPTypes.null, VALUE: None} + fields = {TYPE: AMQPTypes.map, VALUE: cast(List, [])} + for key, data in value.items(): + cast(List, fields[VALUE]).append(({TYPE: AMQPTypes.string, VALUE: key}, data)) + return fields + + +def encode_message_id(value): + # type: (Any) -> Dict[str, Union[int, uuid.UUID, bytes, str]] + """ + type name="message-id-ulong" class="restricted" source="ulong" provides="message-id" + type name="message-id-uuid" class="restricted" source="uuid" provides="message-id" + type name="message-id-binary" class="restricted" source="binary" provides="message-id" + type name="message-id-string" class="restricted" source="string" provides="message-id" + """ + if isinstance(value, int): + return {TYPE: AMQPTypes.ulong, VALUE: value} + if isinstance(value, uuid.UUID): + return {TYPE: AMQPTypes.uuid, VALUE: value} + if isinstance(value, bytes): + return {TYPE: AMQPTypes.binary, VALUE: value} + if isinstance(value, str): + return {TYPE: AMQPTypes.string, VALUE: value} + raise TypeError("Unsupported Message ID type.") + + +def encode_node_properties(value): + # type: (Optional[Dict[str, Any]]) -> Dict[str, Any] + """Properties of a node. + + + + A symbol-keyed map containing properties of a node used when requesting creation or reporting + the creation of a dynamic node. The following common properties are defined:: + + - `lifetime-policy`: The lifetime of a dynamically generated node. Definitionally, the lifetime will + never be less than the lifetime of the link which caused its creation, however it is possible to extend + the lifetime of a dynamically created node using a lifetime policy. The value of this entry MUST be of a type + which provides the lifetime-policy archetype. The following standard lifetime-policies are defined below: + delete-on-close, delete-on-no-links, delete-on-no-messages or delete-on-no-links-or-messages. + + - `supported-dist-modes`: The distribution modes that the node supports. The value of this entry MUST be one or + more symbols which are valid distribution-modes.
That is, the value MUST be of the same type as would be valid + in a field defined with the following attributes: + type="symbol" multiple="true" requires="distribution-mode" + """ + if not value: + return {TYPE: AMQPTypes.null, VALUE: None} + # TODO + fields = {TYPE: AMQPTypes.map, VALUE: []} + # fields[{TYPE: AMQPTypes.symbol, VALUE: b'lifetime-policy'}] = { + # TYPE: AMQPTypes.described, + # VALUE: ( + # {TYPE: AMQPTypes.ulong, VALUE: value['lifetime_policy']}, + # {TYPE: AMQPTypes.list, VALUE: []} + # ) + # } + # fields[{TYPE: AMQPTypes.symbol, VALUE: b'supported-dist-modes'}] = {} + return fields + + +def encode_filter_set(value): + # type: (Optional[Dict[str, Any]]) -> Dict[str, Any] + """A set of predicates to filter the Messages admitted onto the Link. + + + + A set of named filters. Every key in the map MUST be of type symbol, every value MUST be either null or of a + described type which provides the archetype filter. A filter acts as a function on a message which returns a + boolean result indicating whether the message can pass through that filter or not. A message will pass + through a filter-set if and only if it passes through each of the named filters. If the value for a given key is + null, this acts as if there were no such key present (i.e., all messages pass through the null filter). + + Filter types are a defined extension point. The filter types that a given source supports will be indicated + by the capabilities of the source. 
+ """ + if not value: + return {TYPE: AMQPTypes.null, VALUE: None} + fields = {TYPE: AMQPTypes.map, VALUE: cast(List, [])} + for name, data in value.items(): + described_filter: Dict[str, Union[Tuple[Dict[str, Any], Any], Optional[str]]] + if data is None: + described_filter = {TYPE: AMQPTypes.null, VALUE: None} + else: + if isinstance(name, str): + name = name.encode("utf-8") # type: ignore + try: + descriptor, filter_value = data + described_filter = { + TYPE: AMQPTypes.described, + VALUE: ({TYPE: AMQPTypes.symbol, VALUE: descriptor}, filter_value), + } + except ValueError: + described_filter = data + + cast(List, fields[VALUE]).append( + ({TYPE: AMQPTypes.symbol, VALUE: name}, described_filter) + ) + return fields + + +def encode_unknown(output, value, **kwargs): + # type: (bytearray, Optional[Any], Any) -> None + """ + Dynamic encoding according to the type of `value`. + """ + if value is None: + encode_null(output, **kwargs) + elif isinstance(value, bool): + encode_boolean(output, value, **kwargs) + elif isinstance(value, str): + encode_string(output, value, **kwargs) + elif isinstance(value, uuid.UUID): + encode_uuid(output, value, **kwargs) + elif isinstance(value, (bytearray, bytes)): + encode_binary(output, value, **kwargs) + elif isinstance(value, float): + encode_double(output, value, **kwargs) + elif isinstance(value, int): + encode_int(output, value, **kwargs) + elif isinstance(value, datetime): + encode_timestamp(output, value, **kwargs) + elif isinstance(value, list): + encode_list(output, value, **kwargs) + elif isinstance(value, tuple): + encode_described(output, cast(Tuple[Any, Any], value), **kwargs) + elif isinstance(value, dict): + encode_map(output, value, **kwargs) + else: + raise TypeError("Unable to encode unknown value: {}".format(value)) + + +_FIELD_DEFINITIONS = { + FieldDefinition.fields: encode_fields, + FieldDefinition.annotations: encode_annotations, + FieldDefinition.message_id: encode_message_id, + FieldDefinition.app_properties: 
encode_application_properties, + FieldDefinition.node_properties: encode_node_properties, + FieldDefinition.filter_set: encode_filter_set, +} + +_ENCODE_MAP = { + None: encode_unknown, + AMQPTypes.null: encode_null, + AMQPTypes.boolean: encode_boolean, + AMQPTypes.ubyte: encode_ubyte, + AMQPTypes.byte: encode_byte, + AMQPTypes.ushort: encode_ushort, + AMQPTypes.short: encode_short, + AMQPTypes.uint: encode_uint, + AMQPTypes.int: encode_int, + AMQPTypes.ulong: encode_ulong, + AMQPTypes.long: encode_long, + AMQPTypes.float: encode_float, + AMQPTypes.double: encode_double, + AMQPTypes.timestamp: encode_timestamp, + AMQPTypes.uuid: encode_uuid, + AMQPTypes.binary: encode_binary, + AMQPTypes.string: encode_string, + AMQPTypes.symbol: encode_symbol, + AMQPTypes.list: encode_list, + AMQPTypes.map: encode_map, + AMQPTypes.array: encode_array, + AMQPTypes.described: encode_described, +} + + +def encode_value(output, value, **kwargs): + # type: (bytearray, Any, Any) -> None + try: + cast(Callable, _ENCODE_MAP[value[TYPE]])(output, value[VALUE], **kwargs) + except (KeyError, TypeError): + encode_unknown(output, value, **kwargs) + + +def describe_performative(performative): + # type: (Performative) -> Dict[str, Sequence[Collection[str]]] + body: List[Dict[str, Any]] = [] + for index, value in enumerate(performative): + # TODO: fix mypy + field = performative._definition[index] # type: ignore # pylint: disable=protected-access + if value is None: + body.append({TYPE: AMQPTypes.null, VALUE: None}) + elif field is None: + continue + elif isinstance(field.type, FieldDefinition): + if field.multiple: + body.append( + { + TYPE: AMQPTypes.array, + VALUE: [_FIELD_DEFINITIONS[field.type](v) for v in value], + } + ) + else: + body.append(_FIELD_DEFINITIONS[field.type](value)) + elif isinstance(field.type, ObjDefinition): + body.append(describe_performative(value)) + else: + if field.multiple: + body.append( + { + TYPE: AMQPTypes.array, + VALUE: [{TYPE: field.type, VALUE: v} for v in 
value], + } + ) + else: + body.append({TYPE: field.type, VALUE: value}) + + return { + TYPE: AMQPTypes.described, + VALUE: ( + {TYPE: AMQPTypes.ulong, VALUE: performative._code}, # type: ignore # pylint: disable=protected-access + {TYPE: AMQPTypes.list, VALUE: body}, + ), + } + + +def encode_payload(output, payload): + # type: (bytearray, Message) -> bytes + + if payload[0]: # header + # TODO: Header and Properties encoding can be optimized to + # 1. not encoding trailing None fields + # Possible fix 1: + # header = payload[0] + # header = header[0:max(i for i, v in enumerate(header) if v is not None) + 1] + # Possible fix 2: + # itertools.dropwhile(lambda x: x is None, header[::-1]))[::-1] + # 2. encoding bool without constructor + # Possible fix 3: + # header = list(payload[0]) + # while header[-1] is None: + # del header[-1] + encode_value(output, describe_performative(payload[0])) + + if payload[2]: # message annotations + encode_value( + output, + { + TYPE: AMQPTypes.described, + VALUE: ( + {TYPE: AMQPTypes.ulong, VALUE: 0x00000072}, + encode_annotations(payload[2]), + ), + }, + ) + + if payload[3]: # properties + # TODO: Header and Properties encoding can be optimized to + # 1. not encoding trailing None fields + # 2. 
encoding bool without constructor + encode_value(output, describe_performative(payload[3])) + + if payload[4]: # application properties + encode_value( + output, + { + TYPE: AMQPTypes.described, + VALUE: ( + {TYPE: AMQPTypes.ulong, VALUE: 0x00000074}, + encode_application_properties(payload[4]), + ), + }, + ) + + if payload[5]: # data + for item_value in payload[5]: + encode_value( + output, + { + TYPE: AMQPTypes.described, + VALUE: ( + {TYPE: AMQPTypes.ulong, VALUE: 0x00000075}, + {TYPE: AMQPTypes.binary, VALUE: item_value}, + ), + }, + ) + + if payload[6]: # sequence + for item_value in payload[6]: + encode_value( + output, + { + TYPE: AMQPTypes.described, + VALUE: ( + {TYPE: AMQPTypes.ulong, VALUE: 0x00000076}, + {TYPE: None, VALUE: item_value}, + ), + }, + ) + + if payload[7]: # value + encode_value( + output, + { + TYPE: AMQPTypes.described, + VALUE: ( + {TYPE: AMQPTypes.ulong, VALUE: 0x00000077}, + {TYPE: None, VALUE: payload[7]}, + ), + }, + ) + + if payload[8]: # footer + encode_value( + output, + { + TYPE: AMQPTypes.described, + VALUE: ( + {TYPE: AMQPTypes.ulong, VALUE: 0x00000078}, + encode_annotations(payload[8]), + ), + }, + ) + + # TODO: + # currently the delivery annotations must be finally encoded instead of being encoded at the 2nd position + # otherwise the event hubs service would ignore the delivery annotations + # -- received message doesn't have it populated + # check with service team? + if payload[1]: # delivery annotations + encode_value( + output, + { + TYPE: AMQPTypes.described, + VALUE: ( + {TYPE: AMQPTypes.ulong, VALUE: 0x00000071}, + encode_annotations(payload[1]), + ), + }, + ) + + return output + + +def encode_frame(frame, frame_type=_FRAME_TYPE): + # type: (Optional[Performative], bytes) -> Tuple[bytes, Optional[bytes]] + # TODO: allow passing type specific bytes manually, e.g. 
Empty Frame needs padding + if frame is None: + size = 8 + header = size.to_bytes(4, "big") + _FRAME_OFFSET + frame_type + return header, None + + frame_description = describe_performative(frame) + frame_data = bytearray() + encode_value(frame_data, frame_description) + if isinstance(frame, performatives.TransferFrame): + frame_data += frame.payload + + size = len(frame_data) + 8 + header = size.to_bytes(4, "big") + _FRAME_OFFSET + frame_type + return header, frame_data diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/_message_backcompat.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/_message_backcompat.py new file mode 100644 index 0000000000000..e81099b3b7662 --- /dev/null +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/_message_backcompat.py @@ -0,0 +1,258 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. 
+# -------------------------------------------------------------------------- + +# pylint: disable=too-many-lines +from typing import Callable, cast, TYPE_CHECKING +from enum import Enum + +from ._encode import encode_payload +from .utils import get_message_encoded_size +from .error import AMQPError +from .message import Header, Properties + +if TYPE_CHECKING: + from ..amqp._amqp_message import AmqpAnnotatedMessage + +def _encode_property(value): + try: + return value.encode("UTF-8") + except AttributeError: + return value + + +class MessageState(Enum): + WaitingToBeSent = 0 + WaitingForSendAck = 1 + SendComplete = 2 + SendFailed = 3 + ReceivedUnsettled = 4 + ReceivedSettled = 5 + + def __eq__(self, __o: object) -> bool: + try: + return self.value == cast(Enum, __o).value + except AttributeError: + return super().__eq__(__o) + + +class MessageAlreadySettled(Exception): + pass + + +DONE_STATES = (MessageState.SendComplete, MessageState.SendFailed) +RECEIVE_STATES = (MessageState.ReceivedSettled, MessageState.ReceivedUnsettled) +PENDING_STATES = (MessageState.WaitingForSendAck, MessageState.WaitingToBeSent) + + +class LegacyMessage(object): # pylint: disable=too-many-instance-attributes + def __init__(self, message, **kwargs): + self._message: "AmqpAnnotatedMessage" = message + self.state = MessageState.SendComplete + self.idle_time = 0 + self.retries = 0 + self._settler = kwargs.get("settler") + self._encoding = kwargs.get("encoding") + self.delivery_no = kwargs.get("delivery_no") + self.delivery_tag = kwargs.get("delivery_tag") or None + self.on_send_complete = None + self.properties = ( + LegacyMessageProperties(self._message.properties) + if self._message.properties + else None + ) + self.application_properties = ( + self._message.application_properties + if self._message.application_properties and any(self._message.application_properties) + else None + ) + self.annotations = ( + self._message.annotations + if self._message.annotations and 
any(self._message.annotations) + else None + ) + self.header = ( + LegacyMessageHeader(self._message.header) if self._message.header else None + ) + self.footer = self._message.footer + self.delivery_annotations = self._message.delivery_annotations + if self._settler: + self.state = MessageState.ReceivedUnsettled + elif self.delivery_no: + self.state = MessageState.ReceivedSettled + self._to_outgoing_amqp_message: Callable = kwargs.get( + "to_outgoing_amqp_message" + ) + + def __str__(self): + return str(self._message) + + def _can_settle_message(self): + if self.state not in RECEIVE_STATES: + raise TypeError("Only received messages can be settled.") + if self.settled: + return False + return True + + @property + def settled(self): + if self.state == MessageState.ReceivedUnsettled: + return False + return True + + def get_message_encoded_size(self): + return get_message_encoded_size(self._to_outgoing_amqp_message(self._message)) + + def encode_message(self): + output = bytearray() + # to maintain the same behavior as uamqp, app prop values will not be decoded + self.application_properties = self._message.application_properties.copy() + encode_payload(output, self._to_outgoing_amqp_message(self._message)) + return bytes(output) + + def get_data(self): + return self._message.body + + def gather(self): + if self.state in RECEIVE_STATES: + raise TypeError("Only new messages can be gathered.") + if not self._message: + raise ValueError("Message data already consumed.") + if self.state in DONE_STATES: + raise MessageAlreadySettled() + return [self] + + def get_message(self): + return self._to_outgoing_amqp_message(self._message) + + def accept(self): + if self._can_settle_message(): + self._settler.settle_messages(self.delivery_no, "accepted") + self.state = MessageState.ReceivedSettled + return True + return False + + def reject(self, condition=None, description=None, info=None): + if self._can_settle_message(): + self._settler.settle_messages( + self.delivery_no, + 
"rejected", + error=AMQPError( + condition=condition, description=description, info=info + ), + ) + self.state = MessageState.ReceivedSettled + return True + return False + + def release(self): + if self._can_settle_message(): + self._settler.settle_messages(self.delivery_no, "released") + self.state = MessageState.ReceivedSettled + return True + return False + + def modify(self, failed, deliverable, annotations=None): + if self._can_settle_message(): + self._settler.settle_messages( + self.delivery_no, + "modified", + delivery_failed=failed, + undeliverable_here=deliverable, + message_annotations=annotations, + ) + self.state = MessageState.ReceivedSettled + return True + return False + + +class LegacyBatchMessage(LegacyMessage): + batch_format = 0x80013700 + max_message_length = 1024 * 1024 + size_offset = 0 + + +class LegacyMessageProperties(object): # pylint: disable=too-many-instance-attributes + def __init__(self, properties): + self.message_id = _encode_property(properties.message_id) + self.user_id = _encode_property(properties.user_id) + self.to = _encode_property(properties.to) + self.subject = _encode_property(properties.subject) + self.reply_to = _encode_property(properties.reply_to) + self.correlation_id = _encode_property(properties.correlation_id) + self.content_type = _encode_property(properties.content_type) + self.content_encoding = _encode_property(properties.content_encoding) + self.absolute_expiry_time = properties.absolute_expiry_time + self.creation_time = properties.creation_time + self.group_id = _encode_property(properties.group_id) + self.group_sequence = properties.group_sequence + self.reply_to_group_id = _encode_property(properties.reply_to_group_id) + + def __str__(self): + return str( + { + "message_id": self.message_id, + "user_id": self.user_id, + "to": self.to, + "subject": self.subject, + "reply_to": self.reply_to, + "correlation_id": self.correlation_id, + "content_type": self.content_type, + "content_encoding": 
self.content_encoding, + "absolute_expiry_time": self.absolute_expiry_time, + "creation_time": self.creation_time, + "group_id": self.group_id, + "group_sequence": self.group_sequence, + "reply_to_group_id": self.reply_to_group_id, + } + ) + + def get_properties_obj(self): + return Properties( + self.message_id, + self.user_id, + self.to, + self.subject, + self.reply_to, + self.correlation_id, + self.content_type, + self.content_encoding, + self.absolute_expiry_time, + self.creation_time, + self.group_id, + self.group_sequence, + self.reply_to_group_id, + ) + + +class LegacyMessageHeader(object): + def __init__(self, header): + self.delivery_count = header.delivery_count or 0 + self.time_to_live = header.time_to_live + self.first_acquirer = header.first_acquirer + self.durable = header.durable + self.priority = header.priority + + def __str__(self): + return str( + { + "delivery_count": self.delivery_count, + "time_to_live": self.time_to_live, + "first_acquirer": self.first_acquirer, + "durable": self.durable, + "priority": self.priority, + } + ) + + def get_header_obj(self): + # TODO: uamqp returned object has property: `time_to_live`. + # This Header has `ttl`. 
+ return Header( + self.durable, + self.priority, + self.time_to_live, + self.first_acquirer, + self.delivery_count, + ) diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/_platform.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/_platform.py new file mode 100644 index 0000000000000..18d91f710041c --- /dev/null +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/_platform.py @@ -0,0 +1,107 @@ +"""Platform compatibility.""" +# pylint: skip-file + +from __future__ import absolute_import, unicode_literals + +from typing import Tuple, cast +import platform +import re +import struct +import sys + +# Jython does not have this attribute +try: + from socket import SOL_TCP +except ImportError: # pragma: no cover + from socket import IPPROTO_TCP as SOL_TCP # noqa + + +RE_NUM = re.compile(r'(\d+).+') + + +def _linux_version_to_tuple(s): + # type: (str) -> Tuple[int, int, int] + return cast(Tuple[int, int, int], tuple(map(_versionatom, s.split('.')[:3]))) + + +def _versionatom(s): + # type: (str) -> int + if s.isdigit(): + return int(s) + match = RE_NUM.match(s) + return int(match.groups()[0]) if match else 0 + + +# available socket options for TCP level +KNOWN_TCP_OPTS = { + 'TCP_CORK', 'TCP_DEFER_ACCEPT', 'TCP_KEEPCNT', + 'TCP_KEEPIDLE', 'TCP_KEEPINTVL', 'TCP_LINGER2', + 'TCP_MAXSEG', 'TCP_NODELAY', 'TCP_QUICKACK', + 'TCP_SYNCNT', 'TCP_USER_TIMEOUT', 'TCP_WINDOW_CLAMP', +} + +LINUX_VERSION = None +if sys.platform.startswith('linux'): + LINUX_VERSION = _linux_version_to_tuple(platform.release()) + if LINUX_VERSION < (2, 6, 37): + KNOWN_TCP_OPTS.remove('TCP_USER_TIMEOUT') + + # Windows Subsystem for Linux is an edge-case: the Python socket library + # returns most TCP_* enums, but they aren't actually supported + if platform.release().endswith("Microsoft"): + KNOWN_TCP_OPTS = {'TCP_NODELAY', 'TCP_KEEPIDLE', 'TCP_KEEPINTVL', + 'TCP_KEEPCNT'} + +elif sys.platform.startswith('darwin'): + KNOWN_TCP_OPTS.remove('TCP_USER_TIMEOUT') + +elif 
'bsd' in sys.platform: + KNOWN_TCP_OPTS.remove('TCP_USER_TIMEOUT') + +# According to MSDN, Windows platforms support getsockopt(TCP_MAXSEG) but not +# setsockopt(TCP_MAXSEG) on IPPROTO_TCP sockets. +elif sys.platform.startswith('win'): + KNOWN_TCP_OPTS = {'TCP_NODELAY'} + +elif sys.platform.startswith('cygwin'): + KNOWN_TCP_OPTS = {'TCP_NODELAY'} + +# illumos does not allow setting the TCP_MAXSEG socket option, +# even if the Oracle documentation says otherwise. +elif sys.platform.startswith('sunos'): + KNOWN_TCP_OPTS.remove('TCP_MAXSEG') + +# aix does not allow setting the TCP_MAXSEG +# or TCP_USER_TIMEOUT socket options. +elif sys.platform.startswith('aix'): + KNOWN_TCP_OPTS.remove('TCP_MAXSEG') + KNOWN_TCP_OPTS.remove('TCP_USER_TIMEOUT') + +if sys.version_info < (2, 7, 7): # pragma: no cover + import functools + + def _to_bytes_arg(fun): + @functools.wraps(fun) + def _inner(s, *args, **kwargs): + return fun(s.encode(), *args, **kwargs) + return _inner + + pack = _to_bytes_arg(struct.pack) + pack_into = _to_bytes_arg(struct.pack_into) + unpack = _to_bytes_arg(struct.unpack) + unpack_from = _to_bytes_arg(struct.unpack_from) +else: + pack = struct.pack + pack_into = struct.pack_into + unpack = struct.unpack + unpack_from = struct.unpack_from + +__all__ = [ + 'LINUX_VERSION', + 'SOL_TCP', + 'KNOWN_TCP_OPTS', + 'pack', + 'pack_into', + 'unpack', + 'unpack_from', +] diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/_transport.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/_transport.py new file mode 100644 index 0000000000000..7ab5697d51dde --- /dev/null +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/_transport.py @@ -0,0 +1,805 @@ +# ------------------------------------------------------------------------- # pylint: disable=file-needs-copyright-header +# This is a fork of the transport.py which was originally written by Barry Pederson and +# maintained by the Celery project: https://github.com/celery/py-amqp. 
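The `_versionatom`/`_linux_version_to_tuple` helpers in `_platform.py` above tolerate kernel release strings with non-numeric atoms (e.g. `5.15.0-generic`), falling back to a leading-digit match and then to `0`. A standalone sketch of the same parsing logic, re-implemented here for illustration rather than imported from the SDK:

```python
import re
from typing import Tuple, cast

RE_NUM = re.compile(r'(\d+).+')


def versionatom(s: str) -> int:
    # pure digits parse directly; "0-generic" matches the leading digits; "rc1" yields 0
    if s.isdigit():
        return int(s)
    match = RE_NUM.match(s)
    return int(match.groups()[0]) if match else 0


def linux_version_to_tuple(s: str) -> Tuple[int, int, int]:
    # only the first three dot-separated atoms matter for the feature checks
    return cast(Tuple[int, int, int], tuple(map(versionatom, s.split('.')[:3])))


print(linux_version_to_tuple("5.15.0-generic"))          # (5, 15, 0)
print(linux_version_to_tuple("2.6.32.1"))                # (2, 6, 32)
print(linux_version_to_tuple("2.6.32.1") < (2, 6, 37))   # True -> TCP_USER_TIMEOUT is removed
```

This is why the `LINUX_VERSION < (2, 6, 37)` comparison above is safe even on distribution kernels whose release strings carry vendor suffixes.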
+# +# Copyright (C) 2009 Barry Pederson +# +# The license text can also be found here: +# http://www.opensource.org/licenses/BSD-3-Clause +# +# License +# ======= +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS +# BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +# THE POSSIBILITY OF SUCH DAMAGE. 
+# ------------------------------------------------------------------------- + + +from __future__ import absolute_import, unicode_literals + +import errno +import re +import socket +import ssl +import struct +from ssl import SSLError +from contextlib import contextmanager +from io import BytesIO +import logging +from threading import Lock + +import certifi + +from ._platform import KNOWN_TCP_OPTS, SOL_TCP +from ._encode import encode_frame +from ._decode import decode_frame, decode_empty_frame +from .constants import ( + TLS_HEADER_FRAME, + WEBSOCKET_PORT, + TransportType, + AMQP_WS_SUBPROTOCOL, + TIMEOUT_INTERVAL, + WS_TIMEOUT_INTERVAL, + READ_TIMEOUT_INTERVAL, +) +from .error import AuthenticationException, ErrorCondition + + +try: + import fcntl +except ImportError: # pragma: no cover + fcntl = None # type: ignore # noqa + +def set_cloexec(fd, cloexec): # noqa + """Set flag to close fd after exec.""" + if fcntl is None: + return + try: + FD_CLOEXEC = fcntl.FD_CLOEXEC + except AttributeError: + raise NotImplementedError( + "close-on-exec flag not supported on this platform", + ) + flags = fcntl.fcntl(fd, fcntl.F_GETFD) + if cloexec: + flags |= FD_CLOEXEC + else: + flags &= ~FD_CLOEXEC + return fcntl.fcntl(fd, fcntl.F_SETFD, flags) + + +_LOGGER = logging.getLogger(__name__) +_UNAVAIL = {errno.EAGAIN, errno.EINTR, errno.ENOENT, errno.EWOULDBLOCK} + +AMQP_PORT = 5672 +AMQPS_PORT = 5671 +AMQP_FRAME = memoryview(b"AMQP") +EMPTY_BUFFER = bytes() +SIGNED_INT_MAX = 0x7FFFFFFF + +# Match things like: [fe80::1]:5432, from RFC 2732 +IPV6_LITERAL = re.compile(r"\[([\.0-9a-f:]+)\](?::(\d+))?") + +DEFAULT_SOCKET_SETTINGS = { + "TCP_NODELAY": 1, + "TCP_USER_TIMEOUT": 1000, + "TCP_KEEPIDLE": 60, + "TCP_KEEPINTVL": 10, + "TCP_KEEPCNT": 9, +} + + +def get_errno(exc): + """Get exception errno (if set). + + Notes: + :exc:`socket.error` and :exc:`IOError` first got + the ``.errno`` attribute in Py2.7. 
+ """ + try: + return exc.errno + except AttributeError: + try: + # e.args = (errno, reason) + if isinstance(exc.args, tuple) and len(exc.args) == 2: + return exc.args[0] + except AttributeError: + pass + return 0 + + +# TODO: fails when host = hostname:port/path. fix +def to_host_port(host, port=AMQP_PORT): + """Convert hostname:port string to host, port tuple.""" + m = IPV6_LITERAL.match(host) + if m: + host = m.group(1) + if m.group(2): + port = int(m.group(2)) + else: + if ":" in host: + host, port = host.rsplit(":", 1) + port = int(port) + return host, port + + +class UnexpectedFrame(Exception): + pass + + +class _AbstractTransport(object): # pylint: disable=too-many-instance-attributes + """Common superclass for TCP and SSL transports.""" + + def __init__( + self, + host, + *, + port=AMQP_PORT, + connect_timeout=None, + read_timeout=None, + socket_settings=None, + raise_on_initial_eintr=True, + **kwargs + ): + self._quick_recv = None + self.connected = False + self.sock = None + self.raise_on_initial_eintr = raise_on_initial_eintr + self._read_buffer = BytesIO() + self.host, self.port = to_host_port(host, port) + self.network_trace_params = kwargs.get('network_trace_params') + + self.connect_timeout = connect_timeout or TIMEOUT_INTERVAL + self.read_timeout = read_timeout or READ_TIMEOUT_INTERVAL + self.socket_settings = socket_settings + self.socket_lock = Lock() + + def connect(self): + try: + # are we already connected? 
+ if self.connected: + return + self._connect(self.host, self.port, self.connect_timeout) + self._init_socket( + self.socket_settings, + self.read_timeout, + ) + # we've sent the banner; signal connect + # EINTR, EAGAIN, EWOULDBLOCK would signal that the banner + # has _not_ been sent + self.connected = True + except (OSError, IOError, SSLError) as e: + _LOGGER.info("Transport connection failed: %r", e, extra=self.network_trace_params) + # if not fully connected, close socket, and reraise error + if self.sock and not self.connected: + self.sock.close() + self.sock = None + raise + + @contextmanager + def block_with_timeout(self, timeout): + if timeout is None: + yield self.sock + else: + sock = self.sock + prev = sock.gettimeout() + if prev != timeout: + sock.settimeout(timeout) + try: + yield self.sock + except SSLError as exc: + if "timed out" in str(exc): + # http://bugs.python.org/issue10272 + raise socket.timeout() + if "The operation did not complete" in str(exc): + # Non-blocking SSL sockets can throw SSLError + raise socket.timeout() + raise + except socket.error as exc: + if get_errno(exc) == errno.EWOULDBLOCK: + raise socket.timeout() + raise + finally: + if timeout != prev: + sock.settimeout(prev) + + @contextmanager + def block(self): + blocking_timeout = None + sock = self.sock + prev = sock.gettimeout() + if prev != blocking_timeout: + sock.settimeout(blocking_timeout) + try: + yield self.sock + except SSLError as exc: + if "timed out" in str(exc): + # http://bugs.python.org/issue10272 + raise socket.timeout() + if "The operation did not complete" in str(exc): + # Non-blocking SSL sockets can throw SSLError + raise socket.timeout() + raise + except socket.error as exc: + if get_errno(exc) == errno.EWOULDBLOCK: + raise socket.timeout() + raise + finally: + if blocking_timeout != prev: + sock.settimeout(prev) + + @contextmanager + def non_blocking(self): + non_blocking_timeout = 0.0 + sock = self.sock + prev = sock.gettimeout() + if prev != non_blocking_timeout: + sock.settimeout(non_blocking_timeout) + try: + yield self.sock + except SSLError as exc: + if "timed out" in str(exc): + # http://bugs.python.org/issue10272 + raise socket.timeout() + if "The operation did not complete" in str(exc): + # Non-blocking SSL sockets can throw SSLError + raise socket.timeout() + raise + except socket.error as exc: + if get_errno(exc) == errno.EWOULDBLOCK: + raise socket.timeout() + raise + finally: + if non_blocking_timeout != prev: + sock.settimeout(prev) + + def _connect(self, host, port, timeout): + e = None + + # Below we are trying to avoid additional DNS requests for AAAA if A + # succeeds. This helps a lot when a hostname has an IPv4 entry + # in /etc/hosts but not IPv6. Without the (arguably somewhat twisted) + # logic below, getaddrinfo would attempt to resolve the hostname for + # both IP versions, which would make the resolver talk to configured + # DNS servers. If those servers are for some reason not available + # during the resolution attempt (either because of system misconfiguration + # or a network connectivity problem), the resolution process blocks the + # _connect call for an extended time. 
+ addr_types = (socket.AF_INET, socket.AF_INET6) + addr_types_num = len(addr_types) + for n, family in enumerate(addr_types): + # first, resolve the address for a single address family + try: + entries = socket.getaddrinfo( + host, port, family, socket.SOCK_STREAM, SOL_TCP + ) + entries_num = len(entries) + except socket.gaierror: + # we may have depleted all our options + if n + 1 >= addr_types_num: + # if getaddrinfo succeeded before for another address + # family, reraise the previous socket.error since it's more + # relevant to users + raise e if e is not None else socket.error( + "failed to resolve broker hostname" + ) + continue # pragma: no cover + + # now that we have address(es) for the hostname, connect to broker + for i, res in enumerate(entries): + af, socktype, proto, _, sa = res + try: + self.sock = socket.socket(af, socktype, proto) + try: + set_cloexec(self.sock, True) + except NotImplementedError: + pass + self.sock.settimeout(timeout) + self.sock.connect(sa) + except socket.error as ex: + e = ex + if self.sock is not None: + self.sock.close() + self.sock = None + # we may have depleted all our options + if i + 1 >= entries_num and n + 1 >= addr_types_num: + raise + else: + # hurray, we established connection + return + + def _init_socket(self, socket_settings, read_timeout): + self.sock.settimeout(None) # set socket back to blocking mode + self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) + self._set_socket_options(socket_settings) + + # set socket timeouts + # for timeout, interval in ((socket.SO_SNDTIMEO, write_timeout), + # (socket.SO_RCVTIMEO, read_timeout)): + # if interval is not None: + # sec = int(interval) + # usec = int((interval - sec) * 1000000) + # self.sock.setsockopt( + # socket.SOL_SOCKET, timeout, + # pack('ll', sec, usec), + # ) + self._setup_transport() + # TODO: a greater timeout value is needed in long distance communication + # we should either figure out a reasonable value error/dynamically adjust the timeout + 
# 0.2 seconds is enough for perf analysis + self.sock.settimeout(read_timeout) # set socket back to non-blocking mode + + def _get_tcp_socket_defaults(self, sock): # pylint: disable=no-self-use + tcp_opts = {} + for opt in KNOWN_TCP_OPTS: + enum = None + if opt == "TCP_USER_TIMEOUT": + try: + from socket import TCP_USER_TIMEOUT as enum + except ImportError: + # should be in Python 3.6+ on Linux. + enum = 18 + elif hasattr(socket, opt): + enum = getattr(socket, opt) + + if enum: + if opt in DEFAULT_SOCKET_SETTINGS: + tcp_opts[enum] = DEFAULT_SOCKET_SETTINGS[opt] + elif hasattr(socket, opt): + tcp_opts[enum] = sock.getsockopt(SOL_TCP, getattr(socket, opt)) + return tcp_opts + + def _set_socket_options(self, socket_settings): + tcp_opts = self._get_tcp_socket_defaults(self.sock) + if socket_settings: + tcp_opts.update(socket_settings) + for opt, val in tcp_opts.items(): + self.sock.setsockopt(SOL_TCP, opt, val) + + def _read(self, n, initial=False, buffer=None, _errnos=None): + """Read exactly n bytes from the peer.""" + raise NotImplementedError("Must be overridden in subclass") + + def _setup_transport(self): + """Do any additional initialization of the class.""" + + def _shutdown_transport(self): + """Do any preliminary work in shutting down the connection.""" + + def _write(self, s): + """Completely write a string to the peer.""" + raise NotImplementedError("Must be overridden in subclass") + + def close(self): + with self.socket_lock: + if self.sock is not None: + self._shutdown_transport() + # Call shutdown first to make sure that pending messages + # reach the AMQP broker if the program exits after + # calling this method. + try: + self.sock.shutdown(socket.SHUT_RDWR) + except Exception as exc: # pylint: disable=broad-except + # TODO: shutdown could raise OSError, Transport endpoint is not connected if the endpoint is already + # disconnected. Can we safely ignore the errors, since the close operation is initiated by us? 
+ _LOGGER.debug( + "Transport endpoint is already disconnected: %r", + exc, + extra=self.network_trace_params + ) + self.sock.close() + self.sock = None + self.connected = False + + def read(self, verify_frame_type=0): + with self.socket_lock: + read = self._read + read_frame_buffer = BytesIO() + try: + frame_header = memoryview(bytearray(8)) + read_frame_buffer.write(read(8, buffer=frame_header, initial=True)) + + channel = struct.unpack(">H", frame_header[6:])[0] + size = frame_header[0:4] + if size == AMQP_FRAME: # Empty frame or AMQP header negotiation TODO + return frame_header, channel, None + size = struct.unpack(">I", size)[0] + offset = frame_header[4] + frame_type = frame_header[5] + if verify_frame_type is not None and frame_type != verify_frame_type: + _LOGGER.debug( + "Received invalid frame type: %r, expected: %r", + frame_type, + verify_frame_type, + extra=self.network_trace_params + ) + raise ValueError( + f"Received invalid frame type: {frame_type}, expected: {verify_frame_type}" + ) + + # >I is an unsigned int, but the argument to sock.recv is signed, + # so we know the size can be at most 2 * SIGNED_INT_MAX + payload_size = size - len(frame_header) + payload = memoryview(bytearray(payload_size)) + if size > SIGNED_INT_MAX: + read_frame_buffer.write(read(SIGNED_INT_MAX, buffer=payload)) + read_frame_buffer.write( + read(size - SIGNED_INT_MAX, buffer=payload[SIGNED_INT_MAX:]) + ) + else: + read_frame_buffer.write(read(payload_size, buffer=payload)) + except (socket.timeout, TimeoutError): + read_frame_buffer.write(self._read_buffer.getvalue()) + self._read_buffer = read_frame_buffer + self._read_buffer.seek(0) + raise + except (OSError, IOError, SSLError, socket.error) as exc: + # Don't disconnect for ssl read time outs + # http://bugs.python.org/issue10272 + if isinstance(exc, SSLError) and "timed out" in str(exc): + raise socket.timeout() + if get_errno(exc) not in _UNAVAIL: + self.connected = False + _LOGGER.debug("Transport read failed: %r", 
exc, extra=self.network_trace_params) + raise + offset -= 2 + return frame_header, channel, payload[offset:] + + def write(self, s): + with self.socket_lock: + try: + self._write(s) + except socket.timeout: + raise + except (OSError, IOError, socket.error) as exc: + _LOGGER.debug("Transport write failed: %r", exc, extra=self.network_trace_params) + if get_errno(exc) not in _UNAVAIL: + self.connected = False + raise + + def receive_frame(self, **kwargs): + try: + header, channel, payload = self.read(**kwargs) + if not payload: + decoded = decode_empty_frame(header) + else: + decoded = decode_frame(payload) + return channel, decoded + except (socket.timeout, TimeoutError): + return None, None + + def send_frame(self, channel, frame, **kwargs): + header, performative = encode_frame(frame, **kwargs) + if performative is None: + data = header + else: + encoded_channel = struct.pack(">H", channel) + data = header + encoded_channel + performative + self.write(data) + + def negotiate(self): + pass + + +class SSLTransport(_AbstractTransport): + """Transport that works over SSL.""" + + def __init__( + self, host, *, port=AMQPS_PORT, connect_timeout=None, ssl_opts=None, **kwargs + ): + self.sslopts = ssl_opts if isinstance(ssl_opts, dict) else {} + self._read_buffer = BytesIO() + super(SSLTransport, self).__init__( + host, port=port, connect_timeout=connect_timeout, **kwargs + ) + + def _setup_transport(self): + """Wrap the socket in an SSL object.""" + self.sock = self._wrap_socket(self.sock, **self.sslopts) + self.sock.do_handshake() + self._quick_recv = self.sock.recv + + def _wrap_socket(self, sock, context=None, **sslopts): + if context: + return self._wrap_context(sock, sslopts, **context) + return self._wrap_socket_sni(sock, **sslopts) + + def _wrap_context( # pylint: disable=no-self-use + self, sock, sslopts, check_hostname=None, **ctx_options + ): + ctx = ssl.create_default_context(**ctx_options) + ctx.verify_mode = ssl.CERT_REQUIRED + 
ctx.load_verify_locations(cafile=certifi.where()) + ctx.check_hostname = check_hostname + return ctx.wrap_socket(sock, **sslopts) + + def _wrap_socket_sni( # pylint: disable=no-self-use + self, + sock, + keyfile=None, + certfile=None, + server_side=False, + cert_reqs=ssl.CERT_REQUIRED, + ca_certs=None, + do_handshake_on_connect=False, + suppress_ragged_eofs=True, + server_hostname=None, + ciphers=None, + ssl_version=None, + ): + """Socket wrap with SNI headers. + + Default `ssl.wrap_socket` method augmented with support for + setting the server_hostname field required for SNI hostname header + """ + # Setup the right SSL version; default to optimal versions across + # ssl implementations + if ssl_version is None: + ssl_version = ssl.PROTOCOL_TLS + + opts = { + "sock": sock, + "keyfile": keyfile, + "certfile": certfile, + "server_side": server_side, + "cert_reqs": cert_reqs, + "ca_certs": ca_certs, + "do_handshake_on_connect": do_handshake_on_connect, + "suppress_ragged_eofs": suppress_ragged_eofs, + "ciphers": ciphers, + #'ssl_version': ssl_version + } + + # TODO: We need to refactor this. + try: + sock = ssl.wrap_socket(**opts) # pylint: disable=deprecated-method + except FileNotFoundError as exc: + # FileNotFoundError does not have missing filename info, so adding it below. + # Assuming that this must be ca_certs, since this is the only file path that + # users can pass in (`connection_verify` in the EH/SB clients) through opts above. + # For uamqp exception parity. Remove later when resolving issue #27128. 
+ exc.filename = {"ca_certs": ca_certs} + raise exc + # Set SNI headers if supported + if ( + (server_hostname is not None) + and (hasattr(ssl, "HAS_SNI") and ssl.HAS_SNI) + and (hasattr(ssl, "SSLContext")) + ): + # 'ssl_version' is not in opts (it is commented out above), so use the parameter directly + context = ssl.SSLContext(ssl_version) + context.verify_mode = cert_reqs + if cert_reqs != ssl.CERT_NONE: + context.check_hostname = True + if (certfile is not None) and (keyfile is not None): + context.load_cert_chain(certfile, keyfile) + sock = context.wrap_socket(sock, server_hostname=server_hostname) + return sock + + def _shutdown_transport(self): + """Unwrap an SSL socket, so we can call shutdown().""" + if self.sock is not None: + try: + self.sock = self.sock.unwrap() + except OSError: + pass + + def _read( + self, + n, + initial=False, + buffer=None, + _errnos=(errno.ENOENT, errno.EAGAIN, errno.EINTR), + ): + # According to SSL_read(3), it can at most return 16kb of data. + # Thus, we use an internal read buffer like TCPTransport._read + # to get the exact number of bytes wanted. + length = 0 + view = buffer or memoryview(bytearray(n)) + nbytes = self._read_buffer.readinto(view) + toread = n - nbytes + length += nbytes + try: + while toread: + try: + nbytes = self.sock.recv_into(view[length:]) + except socket.error as exc: + # ssl.sock.read may cause an SSLError without errno + # http://bugs.python.org/issue10272 + if isinstance(exc, SSLError) and "timed out" in str(exc): + raise socket.timeout() + # ssl.sock.read may cause ENOENT if the + # operation couldn't be performed (Issue celery#1414). 
+ if exc.errno in _errnos: + if initial and self.raise_on_initial_eintr: + raise socket.timeout() + continue + raise + if not nbytes: + raise IOError("Server unexpectedly closed connection") + + length += nbytes + toread -= nbytes + except: # noqa + self._read_buffer = BytesIO(view[:length]) + raise + return view + + def _write(self, s): + """Write a string out to the SSL socket fully.""" + write = self.sock.send + while s: + try: + n = write(s) + except ValueError: + # AG: sock._sslobj might become null in the meantime if the + # remote connection has hung up. + # In python 3.4, a ValueError is raised if self._sslobj is + # None. + n = 0 + if not n: + raise IOError("Socket closed") + s = s[n:] + + def negotiate(self): + with self.block(): + self.write(TLS_HEADER_FRAME) + _, returned_header = self.receive_frame(verify_frame_type=None) + if returned_header[1] == TLS_HEADER_FRAME: + raise ValueError( + f"""Mismatching TLS header protocol. Expected: {TLS_HEADER_FRAME!r}, """ + f"""received: {returned_header[1]!r}""" + ) + + +def Transport(host, transport_type, connect_timeout=None, ssl_opts=True, **kwargs): + """Create transport. + + Given a few parameters from the Connection constructor, + select and create a subclass of _AbstractTransport. 
+ """ + if transport_type == TransportType.AmqpOverWebsocket: + transport = WebSocketTransport + else: + transport = SSLTransport + return transport(host, connect_timeout=connect_timeout, ssl_opts=ssl_opts, **kwargs) + + +class WebSocketTransport(_AbstractTransport): + def __init__( + self, + host, + *, + port=WEBSOCKET_PORT, + connect_timeout=None, + ssl_opts=None, + **kwargs, + ): + self.sslopts = ssl_opts if isinstance(ssl_opts, dict) else {} + self._connect_timeout = connect_timeout or WS_TIMEOUT_INTERVAL + self._host = host + self._custom_endpoint = kwargs.get("custom_endpoint") + super().__init__(host, port=port, connect_timeout=connect_timeout, **kwargs) + self.ws = None + self._http_proxy = kwargs.get("http_proxy", None) + + def connect(self): + http_proxy_host, http_proxy_port, http_proxy_auth = None, None, None + if self._http_proxy: + http_proxy_host = self._http_proxy["proxy_hostname"] + http_proxy_port = self._http_proxy["proxy_port"] + username = self._http_proxy.get("username", None) + password = self._http_proxy.get("password", None) + if username or password: + http_proxy_auth = (username, password) + try: + from websocket import ( + create_connection, + WebSocketAddressException, + WebSocketTimeoutException, + WebSocketConnectionClosedException + ) + except ImportError: + raise ImportError( + "Please install websocket-client library to use sync websocket transport." 
+ ) + try: + self.ws = create_connection( + url="wss://{}".format(self._custom_endpoint or self._host), + subprotocols=[AMQP_WS_SUBPROTOCOL], + timeout=self._connect_timeout, + skip_utf8_validation=True, + sslopt=self.sslopts, + http_proxy_host=http_proxy_host, + http_proxy_port=http_proxy_port, + http_proxy_auth=http_proxy_auth, + ) + except WebSocketAddressException as exc: + raise AuthenticationException( + ErrorCondition.ClientError, + description="Failed to authenticate the connection due to exception: " + str(exc), + error=exc, + ) + # TODO: resolve pylance error when type: ignore is removed below, issue #22051 + except (WebSocketTimeoutException, SSLError, WebSocketConnectionClosedException) as exc: # type: ignore + self.close() + raise ConnectionError("Websocket failed to establish connection: %r" % exc) from exc + except (OSError, IOError, SSLError) as e: + _LOGGER.info("Websocket connection failed: %r", e, extra=self.network_trace_params) + self.close() + raise + + def _read(self, n, initial=False, buffer=None, _errnos=None): # pylint: disable=unused-argument + """Read exactly n bytes from the peer.""" + from websocket import WebSocketTimeoutException + try: + length = 0 + view = buffer or memoryview(bytearray(n)) + nbytes = self._read_buffer.readinto(view) + length += nbytes + n -= nbytes + try: + while n: + data = self.ws.recv() + if len(data) <= n: + view[length : length + len(data)] = data + n -= len(data) + length += len(data) + else: + view[length : length + n] = data[0:n] + self._read_buffer = BytesIO(data[n:]) + n = 0 + return view + except AttributeError: + raise IOError("Websocket connection has already been closed.") + except WebSocketTimeoutException as wte: + raise TimeoutError('Websocket receive timed out (%s)' % wte) + except: + self._read_buffer = BytesIO(view[:length]) + raise + + def close(self): + with self.socket_lock: + if self.ws: + self._shutdown_transport() + self.ws = None + + def _shutdown_transport(self): + # TODO Sync and Async 
close functions named differently + """Do any preliminary work in shutting down the connection.""" + if self.ws: + self.ws.close() + + def _write(self, s): + """Completely write a string to the peer. + ABNF, OPCODE_BINARY = 0x2 + See http://tools.ietf.org/html/rfc5234 + http://tools.ietf.org/html/rfc6455#section-5.2 + """ + from websocket import WebSocketConnectionClosedException, WebSocketTimeoutException + try: + self.ws.send_binary(s) + except AttributeError: + raise IOError("Websocket connection has already been closed.") + except WebSocketTimeoutException as e: + raise socket.timeout('Websocket send timed out (%s)' % e) + except (WebSocketConnectionClosedException, SSLError) as e: + raise ConnectionError('Websocket disconnected: %r' % e) + \ No newline at end of file diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/aio/__init__.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/aio/__init__.py new file mode 100644 index 0000000000000..bcf047fdb428f --- /dev/null +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/aio/__init__.py @@ -0,0 +1,35 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. 
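The `read` method in `_transport.py` above unpacks the fixed 8-byte AMQP 1.0 frame header by hand: a big-endian 32-bit size (bytes 0–3), a data offset byte (DOFF, byte 4), a frame type byte (byte 5), and a big-endian 16-bit channel (bytes 6–7). A self-contained sketch of that layout using `struct` (the sample header bytes below are fabricated for illustration):

```python
import struct


def parse_frame_header(header: bytes):
    # AMQP 1.0 fixed header: SIZE (>I), DOFF (B), TYPE (B), CHANNEL (>H)
    return struct.unpack(">IBBH", header)


# a hypothetical 19-byte frame on channel 0: size=19, doff=2, type=0 (AMQP frame)
header = struct.pack(">IBBH", 19, 2, 0, 0)
size, doff, frame_type, channel = parse_frame_header(header)
print(size, doff, frame_type, channel)  # 19 2 0 0
print(size - 8)                         # 11 payload bytes follow the 8-byte header
```

The size field counts the whole frame including the header, which is why `read` computes `payload_size = size - len(frame_header)` before draining the socket.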
+# -------------------------------------------------------------------------- + +from ._connection_async import Connection, ConnectionState +from ._link_async import Link, LinkState +from ..constants import LinkDeliverySettleReason +from ._receiver_async import ReceiverLink +from ._sasl_async import SASLPlainCredential, SASLTransport +from ._sender_async import SenderLink +from ._session_async import Session, SessionState +from ._transport_async import AsyncTransport +from ._client_async import AMQPClientAsync, ReceiveClientAsync, SendClientAsync +from ._authentication_async import SASTokenAuthAsync + +__all__ = [ + "Connection", + "ConnectionState", + "Link", + "LinkDeliverySettleReason", + "LinkState", + "ReceiverLink", + "SASLPlainCredential", + "SASLTransport", + "SenderLink", + "Session", + "SessionState", + "AsyncTransport", + "AMQPClientAsync", + "ReceiveClientAsync", + "SendClientAsync", + "SASTokenAuthAsync", +] diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/aio/_authentication_async.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/aio/_authentication_async.py new file mode 100644 index 0000000000000..f6b68b277d6dc --- /dev/null +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/aio/_authentication_async.py @@ -0,0 +1,70 @@ +#------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. 
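The async authentication module below wraps SAS token generation via `_generate_sas_access_token`. For context, a Service Bus shared-access-signature token has a well-known shape: an HMAC-SHA256 signature over the URL-encoded resource URI and an expiry timestamp. The following is a standalone sketch of that format, not the SDK's implementation; the namespace URI, policy name, key, and TTL are all illustrative:

```python
import base64
import hashlib
import hmac
import time
from urllib.parse import quote_plus


def generate_sas_token(uri: str, policy_name: str, key: str, ttl_seconds: int = 3600) -> str:
    # sign "<url-encoded uri>\n<expiry>" with the policy key (HMAC-SHA256)
    expiry = int(time.time()) + ttl_seconds
    encoded_uri = quote_plus(uri)
    to_sign = f"{encoded_uri}\n{expiry}".encode("utf-8")
    signature = base64.b64encode(
        hmac.new(key.encode("utf-8"), to_sign, hashlib.sha256).digest()
    )
    return (
        f"SharedAccessSignature sr={encoded_uri}"
        f"&sig={quote_plus(signature)}&se={expiry}&skn={policy_name}"
    )


token = generate_sas_token(
    "sb://myns.servicebus.windows.net/myqueue",  # hypothetical namespace/entity
    "RootManageSharedAccessKey",
    "secret",
)
print(token.startswith("SharedAccessSignature sr="))  # True
```

The `expires_in`/`expires_on` fields documented in `SASTokenAuthAsync` below correspond to the `se=` component of this token.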
+#------------------------------------------------------------------------- +from functools import partial + +from ..authentication import ( + _generate_sas_access_token, + SASTokenAuth, + JWTTokenAuth +) +from ..constants import AUTH_DEFAULT_EXPIRATION_SECONDS + + +async def _generate_sas_token_async(auth_uri, sas_name, sas_key, expiry_in=AUTH_DEFAULT_EXPIRATION_SECONDS): + return _generate_sas_access_token(auth_uri, sas_name, sas_key, expiry_in=expiry_in) + + +class JWTTokenAuthAsync(JWTTokenAuth): + # TODO: + # 1. naming decision, suffix with Auth vs Credential + ... + + +class SASTokenAuthAsync(SASTokenAuth): + # TODO: + # 1. naming decision, suffix with Auth vs Credential + def __init__( + self, + uri, + audience, + username, + password, + **kwargs + ): + """ + CBS authentication using SAS tokens. + + :param uri: The AMQP endpoint URI. This must be provided as + a decoded string. + :type uri: str + :param audience: The token audience field. For SAS tokens + this is usually the URI. + :type audience: str + :param username: The SAS token username, also referred to as the key + name or policy name. This can optionally be encoded into the URI. + :type username: str + :param password: The SAS token password, also referred to as the key. + This can optionally be encoded into the URI. + :type password: str + :param expires_in: The total remaining seconds until the token + expires. + :type expires_in: int + :param expires_on: The timestamp at which the SAS token will expire + formatted as seconds since epoch. + :type expires_on: float + :param token_type: The type field of the token request. + Default value is `"servicebus.windows.net:sastoken"`. 
+ :type token_type: str + + """ + super(SASTokenAuthAsync, self).__init__( + uri, + audience, + username, + password, + **kwargs + ) + self.get_token = partial(_generate_sas_token_async, uri, username, password, self.expires_in) diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/aio/_cbs_async.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/aio/_cbs_async.py new file mode 100644 index 0000000000000..b0bc86d2f3925 --- /dev/null +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/aio/_cbs_async.py @@ -0,0 +1,260 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# ------------------------------------------------------------------------- + +import logging +from datetime import datetime +from typing import Optional + +from ..utils import utc_now, utc_from_timestamp +from ._management_link_async import ManagementLink +from ..message import Message, Properties +from ..error import AuthenticationException, ErrorCondition, TokenAuthFailure, TokenExpired +from ..constants import ( + CbsState, + CbsAuthState, + CBS_PUT_TOKEN, + CBS_EXPIRATION, + CBS_NAME, + CBS_TYPE, + CBS_OPERATION, + ManagementExecuteOperationResult, + ManagementOpenResult, +) +from ..cbs import check_put_timeout_status, check_expiration_and_refresh_status + +_LOGGER = logging.getLogger(__name__) + + +class CBSAuthenticator(object): # pylint:disable=too-many-instance-attributes + def __init__(self, session, auth, **kwargs): + self._session = session + self._connection = self._session._connection + self._mgmt_link = self._session.create_request_response_link_pair( + endpoint="$cbs", + on_amqp_management_open_complete=self._on_amqp_management_open_complete, + on_amqp_management_error=self._on_amqp_management_error, + status_code_field=b"status-code", + 
status_description_field=b"status-description",
+        )  # type: ManagementLink
+
+        # if not auth.get_token or not asyncio.iscoroutinefunction(auth.get_token):
+        #     raise ValueError("get_token must be a coroutine object.")
+
+        self._auth = auth
+        self._encoding = 'UTF-8'
+        self._auth_timeout = kwargs.get('auth_timeout')
+        self._token_put_time = None
+        self._expires_on = None
+        self._token = None
+        self._refresh_window = None
+        self._network_trace_params = {
+            "amqpConnection": self._session._connection._container_id,
+            "amqpSession": self._session.name,
+            "amqpLink": None
+        }
+
+        self._token_status_code = None
+        self._token_status_description = None
+
+        self.state = CbsState.CLOSED
+        self.auth_state = CbsAuthState.IDLE
+
+    async def _put_token(
+        self, token: str, token_type: str, audience: str, expires_on: Optional[datetime] = None
+    ) -> None:
+        message = Message(  # type: ignore  # TODO: missing positional args header, etc.
+            value=token,
+            properties=Properties(message_id=self._mgmt_link.next_message_id),  # type: ignore
+            application_properties={
+                CBS_NAME: audience,
+                CBS_OPERATION: CBS_PUT_TOKEN,
+                CBS_TYPE: token_type,
+                CBS_EXPIRATION: expires_on,
+            },
+        )
+        await self._mgmt_link.execute_operation(
+            message,
+            self._on_execute_operation_complete,
+            timeout=self._auth_timeout,
+            operation=CBS_PUT_TOKEN,
+            type=token_type,
+        )
+        self._mgmt_link.next_message_id += 1
+
+    async def _on_amqp_management_open_complete(self, management_open_result):
+        if self.state in (CbsState.CLOSED, CbsState.ERROR):
+            _LOGGER.debug(
+                "CBS with status: %r encountered unexpected AMQP management open complete.",
+                self.state,
+                extra=self._network_trace_params
+            )
+        elif self.state == CbsState.OPEN:
+            self.state = CbsState.ERROR
+            _LOGGER.info(
+                "Unexpected AMQP management open complete in OPEN, CBS error occurred.",
+                extra=self._network_trace_params
+            )
+        elif self.state == CbsState.OPENING:
+            self.state = CbsState.OPEN if management_open_result == ManagementOpenResult.OK else
CbsState.CLOSED + _LOGGER.info( + "CBS completed opening with status: %r", + management_open_result, + extra=self._network_trace_params + ) + + async def _on_amqp_management_error(self): + if self.state == CbsState.CLOSED: + _LOGGER.debug("Unexpected AMQP error in CLOSED state.", extra=self._network_trace_params) + elif self.state == CbsState.OPENING: + self.state = CbsState.ERROR + await self._mgmt_link.close() + _LOGGER.info( + "CBS failed to open with status: %r", + ManagementOpenResult.ERROR, + extra=self._network_trace_params + ) + elif self.state == CbsState.OPEN: + self.state = CbsState.ERROR + _LOGGER.info("CBS error occurred.", extra=self._network_trace_params) + + async def _on_execute_operation_complete( + self, execute_operation_result, status_code, status_description, _, error_condition=None + ): + if error_condition: + _LOGGER.info( + "CBS Put token error: %r", + error_condition, + extra=self._network_trace_params + ) + self.auth_state = CbsAuthState.ERROR + return + _LOGGER.debug( + "CBS Put token result (%r), status code: %s, status_description: %s.", + execute_operation_result, + status_code, + status_description, + extra=self._network_trace_params + ) + self._token_status_code = status_code + self._token_status_description = status_description + + if execute_operation_result == ManagementExecuteOperationResult.OK: + self.auth_state = CbsAuthState.OK + elif execute_operation_result == ManagementExecuteOperationResult.ERROR: + self.auth_state = CbsAuthState.ERROR + # put-token-message sending failure, rejected + self._token_status_code = 0 + self._token_status_description = "Auth message has been rejected." 
+        elif execute_operation_result == ManagementExecuteOperationResult.FAILED_BAD_STATUS:
+            self.auth_state = CbsAuthState.ERROR
+
+    async def _update_status(self):
+        if self.auth_state == CbsAuthState.OK or self.auth_state == CbsAuthState.REFRESH_REQUIRED:
+            is_expired, is_refresh_required = check_expiration_and_refresh_status(
+                self._expires_on, self._refresh_window
+            )
+            _LOGGER.debug(
+                "CBS status check: state == %r, expired == %r, refresh required == %r",
+                self.auth_state,
+                is_expired,
+                is_refresh_required,
+                extra=self._network_trace_params
+            )
+            if is_expired:
+                self.auth_state = CbsAuthState.EXPIRED
+            elif is_refresh_required:
+                self.auth_state = CbsAuthState.REFRESH_REQUIRED
+        elif self.auth_state == CbsAuthState.IN_PROGRESS:
+            _LOGGER.debug(
+                "CBS update in progress. Token put time: %r",
+                self._token_put_time,
+                extra=self._network_trace_params
+            )
+            put_timeout = check_put_timeout_status(self._auth_timeout, self._token_put_time)
+            if put_timeout:
+                self.auth_state = CbsAuthState.TIMEOUT
+
+    async def _cbs_link_ready(self):
+        # Raise for broken states first: the original ordering
+        # (return True/False before this check) made this branch unreachable.
+        if self.state in (CbsState.CLOSED, CbsState.ERROR):
+            raise TokenAuthFailure(
+                status_code=ErrorCondition.ClientError,
+                status_description="CBS authentication link is in broken status, please recreate the cbs link.",
+            )
+        return self.state == CbsState.OPEN
+
+    async def open(self):
+        self.state = CbsState.OPENING
+        await self._mgmt_link.open()
+
+    async def close(self):
+        await self._mgmt_link.close()
+        self.state = CbsState.CLOSED
+
+    async def update_token(self):
+        self.auth_state = CbsAuthState.IN_PROGRESS
+        access_token = await self._auth.get_token()
+        if not access_token:
+            _LOGGER.info(
+                "Token refresh function received an empty token object.",
+                extra=self._network_trace_params
+            )
+        elif not access_token.token:
+            _LOGGER.info(
+                "Token refresh function received an empty token.",
+                extra=self._network_trace_params
+            )
+
self._expires_on = access_token.expires_on + expires_in = self._expires_on - int(utc_now().timestamp()) + self._refresh_window = int(float(expires_in) * 0.1) + try: + self._token = access_token.token.decode() + except AttributeError: + self._token = access_token.token + try: + token_type = self._auth.token_type.decode() + except AttributeError: + token_type = self._auth.token_type + + self._token_put_time = int(utc_now().timestamp()) + await self._put_token( + self._token, token_type, self._auth.audience, utc_from_timestamp(self._expires_on) + ) + + async def handle_token(self): + if not await self._cbs_link_ready(): + return False + await self._update_status() + if self.auth_state == CbsAuthState.IDLE: + await self.update_token() + return False + if self.auth_state == CbsAuthState.IN_PROGRESS: + return False + if self.auth_state == CbsAuthState.OK: + return True + if self.auth_state == CbsAuthState.REFRESH_REQUIRED: + _LOGGER.info( + "Token will expire soon - attempting to refresh.", + extra=self._network_trace_params + ) + await self.update_token() + return False + if self.auth_state == CbsAuthState.FAILURE: + raise AuthenticationException( + condition=ErrorCondition.InternalError, description="Failed to open CBS authentication link." 
+ ) + if self.auth_state == CbsAuthState.ERROR: + raise TokenAuthFailure( + self._token_status_code, + self._token_status_description, + encoding=self._encoding, # TODO: drop off all the encodings + ) + if self.auth_state == CbsAuthState.TIMEOUT: + raise TimeoutError("Authentication attempt timed-out.") + if self.auth_state == CbsAuthState.EXPIRED: + raise TokenExpired(condition=ErrorCondition.InternalError, description="CBS Authentication Expired.") diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/aio/_client_async.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/aio/_client_async.py new file mode 100644 index 0000000000000..0bbc3f8bf33b8 --- /dev/null +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/aio/_client_async.py @@ -0,0 +1,965 @@ +#------------------------------------------------------------------------- # pylint: disable=client-suffix-needed +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. 
+#--------------------------------------------------------------------------
+# TODO: Check types of kwargs (issue exists for this)
+import asyncio
+import logging
+import time
+import queue
+from functools import partial
+from typing import Any, Dict, Optional, Tuple, Union, overload, cast
+from typing_extensions import Literal
+import certifi
+
+from ..outcomes import Accepted, Modified, Received, Rejected, Released
+from ._connection_async import Connection
+from ._management_operation_async import ManagementOperation
+from ._cbs_async import CBSAuthenticator
+from ..client import (
+    AMQPClient as AMQPClientSync,
+    ReceiveClient as ReceiveClientSync,
+    SendClient as SendClientSync,
+    Outcomes
+)
+from ..message import _MessageDelivery
+from ..constants import (
+    MessageDeliveryState,
+    SEND_DISPOSITION_ACCEPT,
+    SEND_DISPOSITION_REJECT,
+    LinkDeliverySettleReason,
+    MESSAGE_DELIVERY_DONE_STATES,
+    AUTH_TYPE_CBS,
+)
+from ..error import (
+    AMQPError,
+    ErrorCondition,
+    AMQPException,
+    MessageException
+)
+from ..constants import LinkState
+
+_logger = logging.getLogger(__name__)
+
+
+class AMQPClientAsync(AMQPClientSync):
+    """An asynchronous AMQP client.
+
+    :param hostname: The AMQP endpoint to connect to.
+    :type hostname: str
+    :keyword auth: Authentication for the connection. This should be one of the following:
+        - pyamqp.authentication.SASLAnonymous
+        - pyamqp.authentication.SASLPlain
+        - pyamqp.authentication.SASTokenAuth
+        - pyamqp.authentication.JWTTokenAuth
+        If no authentication is supplied, SASLAnonymous will be used by default.
+    :paramtype auth: ~pyamqp.authentication
+    :keyword client_name: The name for the client, also known as the Container ID.
+        If no name is provided, a random GUID will be used.
+    :paramtype client_name: str or bytes
+    :keyword network_trace: Whether to turn on network trace logs. If `True`, trace logs
+        will be logged at INFO level. Default is `False`.
+ :paramtype network_trace: bool + :keyword retry_policy: A policy for parsing errors on link, connection and message + disposition to determine whether the error should be retryable. + :paramtype retry_policy: ~pyamqp.error.RetryPolicy + :keyword keep_alive_interval: If set, a thread will be started to keep the connection + alive during periods of user inactivity. The value will determine how long the + thread will sleep (in seconds) between pinging the connection. If 0 or None, no + thread will be started. + :paramtype keep_alive_interval: int + :keyword max_frame_size: Maximum AMQP frame size. Default is 63488 bytes. + :paramtype max_frame_size: int + :keyword channel_max: Maximum number of Session channels in the Connection. + :paramtype channel_max: int + :keyword idle_timeout: Timeout in seconds after which the Connection will close + if there is no further activity. + :paramtype idle_timeout: int + :keyword auth_timeout: Timeout in seconds for CBS authentication. Otherwise this value will be ignored. + Default value is 60s. + :paramtype auth_timeout: int + :keyword properties: Connection properties. + :paramtype properties: dict[str, any] + :keyword remote_idle_timeout_empty_frame_send_ratio: Ratio of empty frames to + idle time for Connections with no activity. Value must be between + 0.0 and 1.0 inclusive. Default is 0.5. + :paramtype remote_idle_timeout_empty_frame_send_ratio: float + :keyword incoming_window: The size of the allowed window for incoming messages. + :paramtype incoming_window: int + :keyword outgoing_window: The size of the allowed window for outgoing messages. + :paramtype outgoing_window: int + :keyword handle_max: The maximum number of concurrent link handles. + :paramtype handle_max: int + :keyword on_attach: A callback function to be run on receipt of an ATTACH frame. + The function must take 4 arguments: source, target, properties and error. 
+ :paramtype on_attach: func[ + ~pyamqp.endpoint.Source, ~pyamqp.endpoint.Target, dict, ~pyamqp.error.AMQPConnectionError] + :keyword send_settle_mode: The mode by which to settle message send + operations. If set to `Unsettled`, the client will wait for a confirmation + from the service that the message was successfully sent. If set to 'Settled', + the client will not wait for confirmation and assume success. + :paramtype send_settle_mode: ~pyamqp.constants.SenderSettleMode + :keyword receive_settle_mode: The mode by which to settle message receive + operations. If set to `PeekLock`, the receiver will lock a message once received until + the client accepts or rejects the message. If set to `ReceiveAndDelete`, the service + will assume successful receipt of the message and clear it from the queue. The + default is `PeekLock`. + :paramtype receive_settle_mode: ~pyamqp.constants.ReceiverSettleMode + :keyword desired_capabilities: The extension capabilities desired from the peer endpoint. + :paramtype desired_capabilities: list[bytes] + :keyword max_message_size: The maximum allowed message size negotiated for the Link. + :paramtype max_message_size: int + :keyword link_properties: Metadata to be sent in the Link ATTACH frame. + :paramtype link_properties: dict[str, any] + :keyword link_credit: The Link credit that determines how many + messages the Link will attempt to handle per connection iteration. + The default is 300. + :paramtype link_credit: int + :keyword transport_type: The type of transport protocol that will be used for communicating with + the service. Default is `TransportType.Amqp` in which case port 5671 is used. + If the port 5671 is unavailable/blocked in the network environment, `TransportType.AmqpOverWebsocket` could + be used instead which uses port 443 for communication. + :paramtype transport_type: ~pyamqp.constants.TransportType + :keyword http_proxy: HTTP proxy settings. 
This must be a dictionary with the following + keys: `'proxy_hostname'` (str value) and `'proxy_port'` (int value). + Additionally the following keys may also be present: `'username', 'password'`. + :paramtype http_proxy: dict[str, str] + :keyword custom_endpoint_address: The custom endpoint address to use for establishing a connection to + the Event Hubs service, allowing network requests to be routed through any application gateways or + other paths needed for the host environment. Default is None. + If port is not specified in the `custom_endpoint_address`, by default port 443 will be used. + :paramtype custom_endpoint_address: str + :keyword connection_verify: Path to the custom CA_BUNDLE file of the SSL certificate which is used to + authenticate the identity of the connection endpoint. + Default is None in which case `certifi.where()` will be used. + :paramtype connection_verify: str + """ + + def __init__(self, hostname, **kwargs): + self._mgmt_link_lock_async = asyncio.Lock() + super().__init__(hostname,**kwargs) + + + async def _keep_alive_async(self): + start_time = time.time() + try: + while self._connection and not self._shutdown: + current_time = time.time() + elapsed_time = current_time - start_time + if elapsed_time >= self._keep_alive_interval: + _logger.debug( + "Keeping %r connection alive.", + self.__class__.__name__, + extra=self._network_trace_params + ) + await asyncio.shield(self._connection.listen(wait=self._socket_timeout, + batch=self._link.current_link_credit)) + start_time = current_time + await asyncio.sleep(1) + except Exception as e: # pylint: disable=broad-except + _logger.info( + "Connection keep-alive for %r failed: %r.", + self.__class__.__name__, + e, + extra=self._network_trace_params + ) + + async def __aenter__(self): + """Run Client in an async context manager.""" + await self.open_async() + return self + + async def __aexit__(self, *args): + """Close and destroy Client on exiting an async context manager.""" + await 
self.close_async() + + async def _client_ready_async(self): # pylint: disable=no-self-use + """Determine whether the client is ready to start sending and/or + receiving messages. To be ready, the connection must be open and + authentication complete. + + :rtype: bool + """ + return True + + async def _client_run_async(self, **kwargs): + """Perform a single Connection iteration.""" + await self._connection.listen(wait=self._socket_timeout, **kwargs) + + async def _close_link_async(self): + if self._link and not self._link._is_closed: # pylint: disable=protected-access + await self._link.detach(close=True) + self._link = None + + async def _do_retryable_operation_async(self, operation, *args, **kwargs): + retry_settings = self._retry_policy.configure_retries() + retry_active = True + absolute_timeout = kwargs.pop("timeout", 0) or 0 + start_time = time.time() + while retry_active: + try: + if absolute_timeout < 0: + raise TimeoutError("Operation timed out.") + return await operation(*args, timeout=absolute_timeout, **kwargs) + except AMQPException as exc: + if not self._retry_policy.is_retryable(exc): + raise + if absolute_timeout >= 0: + retry_active = self._retry_policy.increment(retry_settings, exc) + if not retry_active: + break + await asyncio.sleep(self._retry_policy.get_backoff_time(retry_settings, exc)) + if exc.condition == ErrorCondition.LinkDetachForced: + await self._close_link_async() # if link level error, close and open a new link + if exc.condition in (ErrorCondition.ConnectionCloseForced, ErrorCondition.SocketError): + # if connection detach or socket error, close and open a new connection + await self.close_async() + finally: + end_time = time.time() + if absolute_timeout > 0: + absolute_timeout -= (end_time - start_time) + raise retry_settings['history'][-1] + + async def open_async(self, connection=None): + """Asynchronously open the client. The client can create a new Connection + or an existing Connection can be passed in. 
This existing Connection + may have an existing CBS authentication Session, which will be + used for this client as well. Otherwise a new Session will be + created. + + :param connection: An existing Connection that may be shared between + multiple clients. + :type connection: ~pyamqp.aio.Connection + """ + # pylint: disable=protected-access + if self._session: + return # already open. + if connection: + self._connection = connection + self._external_connection = True + if not self._connection: + self._connection = Connection( + "amqps://" + self._hostname, + sasl_credential=self._auth.sasl, + ssl_opts={'ca_certs': self._connection_verify or certifi.where()}, + container_id=self._name, + max_frame_size=self._max_frame_size, + channel_max=self._channel_max, + idle_timeout=self._idle_timeout, + properties=self._properties, + network_trace=self._network_trace, + transport_type=self._transport_type, + http_proxy=self._http_proxy, + custom_endpoint_address=self._custom_endpoint_address + ) + await self._connection.open() + if not self._session: + self._session = self._connection.create_session( + incoming_window=self._incoming_window, + outgoing_window=self._outgoing_window + ) + await self._session.begin() + if self._auth.auth_type == AUTH_TYPE_CBS: + self._cbs_authenticator = CBSAuthenticator( + session=self._session, + auth=self._auth, + auth_timeout=self._auth_timeout + ) + await self._cbs_authenticator.open() + self._network_trace_params["amqpConnection"] = self._connection._container_id + self._network_trace_params["amqpSession"] = self._session.name + self._shutdown = False + + if self._keep_alive_interval: + self._keep_alive_thread = asyncio.ensure_future(self._keep_alive_async()) + + async def close_async(self): + """Close the client asynchronously. This includes closing the Session + and CBS authentication layer as well as the Connection. + If the client was opened using an external Connection, + this will be left intact. 
+ """ + self._shutdown = True + if not self._session: + return # already closed. + await self._close_link_async() + if self._cbs_authenticator: + await self._cbs_authenticator.close() + self._cbs_authenticator = None + await self._session.end() + self._session = None + if not self._external_connection: + await self._connection.close() + self._connection = None + if self._keep_alive_thread: + await self._keep_alive_thread + self._keep_alive_thread = None + self._network_trace_params["amqpConnection"] = None + self._network_trace_params["amqpSession"] = None + + async def auth_complete_async(self): + """Whether the authentication handshake is complete during + connection initialization. + + :rtype: bool + """ + if self._cbs_authenticator and not await self._cbs_authenticator.handle_token(): + await self._connection.listen(wait=self._socket_timeout) + return False + return True + + async def client_ready_async(self): + """ + Whether the handler has completed all start up processes such as + establishing the connection, session, link and authentication, and + is not ready to process messages. + + :rtype: bool + """ + if not await self.auth_complete_async(): + return False + if not await self._client_ready_async(): + try: + await self._connection.listen(wait=self._socket_timeout) + except ValueError: + return True + return False + return True + + async def do_work_async(self, **kwargs): + """Run a single connection iteration asynchronously. + This will return `True` if the connection is still open + and ready to be used for further work, or `False` if it needs + to be shut down. + + :rtype: bool + :raises: TimeoutError if CBS authentication timeout reached. + """ + + if self._shutdown: + return False + if not await self.client_ready_async(): + return True + return await self._client_run_async(**kwargs) + + async def mgmt_request_async(self, message, **kwargs): + """ + :param message: The message to send in the management request. 
+ :type message: ~pyamqp.message.Message + :keyword str operation: The type of operation to be performed. This value will + be service-specific, but common values include READ, CREATE and UPDATE. + This value will be added as an application property on the message. + :keyword str operation_type: The type on which to carry out the operation. This will + be specific to the entities of the service. This value will be added as + an application property on the message. + :keyword str node: The target node. Default node is `$management`. + :keyword float timeout: Provide an optional timeout in seconds within which a response + to the management request must be received. + :rtype: ~pyamqp.message.Message + """ + + # The method also takes "status_code_field" and "status_description_field" + # keyword arguments as alternate names for the status code and description + # in the response body. Those two keyword arguments are used in Azure services only. + operation = kwargs.pop("operation", None) + operation_type = kwargs.pop("operation_type", None) + node = kwargs.pop("node", "$management") + timeout = kwargs.pop('timeout', 0) + async with self._mgmt_link_lock_async: + try: + mgmt_link = self._mgmt_links[node] + except KeyError: + mgmt_link = ManagementOperation(self._session, endpoint=node, **kwargs) + self._mgmt_links[node] = mgmt_link + await mgmt_link.open() + + while not await mgmt_link.ready(): + await self._connection.listen(wait=False) + + operation_type = operation_type or b'empty' + status, description, response = await mgmt_link.execute( + message, + operation=operation, + operation_type=operation_type, + timeout=timeout + ) + return status, description, response + + +class SendClientAsync(SendClientSync, AMQPClientAsync): + + """An asynchronous AMQP client. + + :param target: The target AMQP service endpoint. This can either be the URI as + a string or a ~pyamqp.endpoint.Target object. 
+    :type target: str, bytes or ~pyamqp.endpoint.Target
+    :keyword auth: Authentication for the connection. This should be one of the following:
+        - pyamqp.authentication.SASLAnonymous
+        - pyamqp.authentication.SASLPlain
+        - pyamqp.authentication.SASTokenAuth
+        - pyamqp.authentication.JWTTokenAuth
+        If no authentication is supplied, SASLAnonymous will be used by default.
+    :paramtype auth: ~pyamqp.authentication
+    :keyword client_name: The name for the client, also known as the Container ID.
+        If no name is provided, a random GUID will be used.
+    :paramtype client_name: str or bytes
+    :keyword network_trace: Whether to turn on network trace logs. If `True`, trace logs
+        will be logged at INFO level. Default is `False`.
+    :paramtype network_trace: bool
+    :keyword retry_policy: A policy for parsing errors on link, connection and message
+        disposition to determine whether the error should be retryable.
+    :paramtype retry_policy: ~pyamqp.error.RetryPolicy
+    :keyword keep_alive_interval: If set, a thread will be started to keep the connection
+        alive during periods of user inactivity. The value will determine how long the
+        thread will sleep (in seconds) between pinging the connection. If 0 or None, no
+        thread will be started.
+    :paramtype keep_alive_interval: int
+    :keyword max_frame_size: Maximum AMQP frame size. Default is 63488 bytes.
+    :paramtype max_frame_size: int
+    :keyword channel_max: Maximum number of Session channels in the Connection.
+    :paramtype channel_max: int
+    :keyword idle_timeout: Timeout in seconds after which the Connection will close
+        if there is no further activity.
+    :paramtype idle_timeout: int
+    :keyword auth_timeout: Timeout in seconds for CBS authentication. Otherwise this value will be ignored.
+        Default value is 60s.
+    :paramtype auth_timeout: int
+    :keyword properties: Connection properties.
+ :paramtype properties: dict[str, any] + :keyword remote_idle_timeout_empty_frame_send_ratio: Ratio of empty frames to + idle time for Connections with no activity. Value must be between + 0.0 and 1.0 inclusive. Default is 0.5. + :paramtype remote_idle_timeout_empty_frame_send_ratio: float + :keyword incoming_window: The size of the allowed window for incoming messages. + :paramtype incoming_window: int + :keyword outgoing_window: The size of the allowed window for outgoing messages. + :paramtype outgoing_window: int + :keyword handle_max: The maximum number of concurrent link handles. + :paramtype handle_max: int + :keyword on_attach: A callback function to be run on receipt of an ATTACH frame. + The function must take 4 arguments: source, target, properties and error. + :paramtype on_attach: func[ + ~pyamqp.endpoint.Source, ~pyamqp.endpoint.Target, dict, ~pyamqp.error.AMQPConnectionError] + :keyword send_settle_mode: The mode by which to settle message send + operations. If set to `Unsettled`, the client will wait for a confirmation + from the service that the message was successfully sent. If set to 'Settled', + the client will not wait for confirmation and assume success. + :paramtype send_settle_mode: ~pyamqp.constants.SenderSettleMode + :keyword receive_settle_mode: The mode by which to settle message receive + operations. If set to `PeekLock`, the receiver will lock a message once received until + the client accepts or rejects the message. If set to `ReceiveAndDelete`, the service + will assume successful receipt of the message and clear it from the queue. The + default is `PeekLock`. + :paramtype receive_settle_mode: ~pyamqp.constants.ReceiverSettleMode + :keyword desired_capabilities: The extension capabilities desired from the peer endpoint. + :paramtype desired_capabilities: list[bytes] + :keyword max_message_size: The maximum allowed message size negotiated for the Link. 
+ :paramtype max_message_size: int + :keyword link_properties: Metadata to be sent in the Link ATTACH frame. + :paramtype link_properties: dict[str, any] + :keyword link_credit: The Link credit that determines how many + messages the Link will attempt to handle per connection iteration. + The default is 300. + :paramtype link_credit: int + :keyword transport_type: The type of transport protocol that will be used for communicating with + the service. Default is `TransportType.Amqp` in which case port 5671 is used. + If the port 5671 is unavailable/blocked in the network environment, `TransportType.AmqpOverWebsocket` could + be used instead which uses port 443 for communication. + :paramtype transport_type: ~pyamqp.constants.TransportType + :keyword http_proxy: HTTP proxy settings. This must be a dictionary with the following + keys: `'proxy_hostname'` (str value) and `'proxy_port'` (int value). + Additionally the following keys may also be present: `'username', 'password'`. + :paramtype http_proxy: dict[str, str] + :keyword custom_endpoint_address: The custom endpoint address to use for establishing a connection to + the Event Hubs service, allowing network requests to be routed through any application gateways or + other paths needed for the host environment. Default is None. + If port is not specified in the `custom_endpoint_address`, by default port 443 will be used. + :paramtype custom_endpoint_address: str + :keyword connection_verify: Path to the custom CA_BUNDLE file of the SSL certificate which is used to + authenticate the identity of the connection endpoint. + Default is None in which case `certifi.where()` will be used. + :paramtype connection_verify: str + """ + + async def _client_ready_async(self): + """Determine whether the client is ready to start receiving messages. + To be ready, the connection must be open and authentication complete, + The Session, Link and MessageReceiver must be open and in non-errored + states. 
+ + :rtype: bool + """ + # pylint: disable=protected-access + if not self._link: + self._link = self._session.create_sender_link( + target_address=self.target, + link_credit=self._link_credit, + send_settle_mode=self._send_settle_mode, + rcv_settle_mode=self._receive_settle_mode, + max_message_size=self._max_message_size, + properties=self._link_properties) + await self._link.attach() + return False + if self._link.get_state().value != 3: # ATTACHED + return False + return True + + async def _client_run_async(self, **kwargs): + """MessageSender Link is now open - perform message send + on all pending messages. + Will return True if operation successful and client can remain open for + further work. + + :rtype: bool + """ + await self._link.update_pending_deliveries() + await self._connection.listen(wait=self._socket_timeout, **kwargs) + return True + + async def _transfer_message_async(self, message_delivery, timeout=0): + message_delivery.state = MessageDeliveryState.WaitingForSendAck + on_send_complete = partial(self._on_send_complete_async, message_delivery) + delivery = await self._link.send_transfer( + message_delivery.message, + on_send_complete=on_send_complete, + timeout=timeout, + send_async=True + ) + return delivery + + async def _on_send_complete_async(self, message_delivery, reason, state): + message_delivery.reason = reason + if reason == LinkDeliverySettleReason.DISPOSITION_RECEIVED: + if state and SEND_DISPOSITION_ACCEPT in state: + message_delivery.state = MessageDeliveryState.Ok + else: + try: + error_info = state[SEND_DISPOSITION_REJECT] + self._process_send_error( + message_delivery, + condition=error_info[0][0], + description=error_info[0][1], + info=error_info[0][2] + ) + except TypeError: + self._process_send_error( + message_delivery, + condition=ErrorCondition.UnknownError + ) + elif reason == LinkDeliverySettleReason.SETTLED: + message_delivery.state = MessageDeliveryState.Ok + elif reason == LinkDeliverySettleReason.TIMEOUT: + 
message_delivery.state = MessageDeliveryState.Timeout + message_delivery.error = TimeoutError("Sending message timed out.") + else: + # NotDelivered and other unknown errors + self._process_send_error( + message_delivery, + condition=ErrorCondition.UnknownError + ) + + async def _send_message_impl_async(self, message, **kwargs): + timeout = kwargs.pop("timeout", 0) + expire_time = (time.time() + timeout) if timeout else None + await self.open_async() + message_delivery = _MessageDelivery( + message, + MessageDeliveryState.WaitingToBeSent, + expire_time + ) + + while not await self.client_ready_async(): + await asyncio.sleep(0.05) + + await self._transfer_message_async(message_delivery, timeout) + + running = True + while running and message_delivery.state not in MESSAGE_DELIVERY_DONE_STATES: + running = await self.do_work_async() + if message_delivery.state not in MESSAGE_DELIVERY_DONE_STATES: + raise MessageException( + condition=ErrorCondition.ClientError, + description="Send failed - connection not running." + ) + + if message_delivery.state in ( + MessageDeliveryState.Error, + MessageDeliveryState.Cancelled, + MessageDeliveryState.Timeout + ): + try: + raise message_delivery.error # pylint: disable=raising-bad-type + except TypeError: + # This is a default handler + raise MessageException(condition=ErrorCondition.UnknownError, description="Send failed.") + + async def send_message_async(self, message, **kwargs): + """ + :param ~pyamqp.message.Message message: + :param int timeout: timeout in seconds + """ + await self._do_retryable_operation_async(self._send_message_impl_async, message=message, **kwargs) + + +class ReceiveClientAsync(ReceiveClientSync, AMQPClientAsync): + """An asynchronous AMQP client. + + :param source: The source AMQP service endpoint. This can either be the URI as + a string or a ~pyamqp.endpoint.Source object. + :type source: str, bytes or ~pyamqp.endpoint.Source + :keyword auth: Authentication for the connection. 
This should be one of the following:
+            - pyamqp.authentication.SASLAnonymous
+            - pyamqp.authentication.SASLPlain
+            - pyamqp.authentication.SASTokenAuth
+            - pyamqp.authentication.JWTTokenAuth
+        If no authentication is supplied, SASLAnonymous will be used by default.
+    :paramtype auth: ~pyamqp.authentication
+    :keyword client_name: The name for the client, also known as the Container ID.
+        If no name is provided, a random GUID will be used.
+    :paramtype client_name: str or bytes
+    :keyword network_trace: Whether to turn on network trace logs. If `True`, trace logs
+        will be logged at INFO level. Default is `False`.
+    :paramtype network_trace: bool
+    :keyword retry_policy: A policy for parsing errors on link, connection and message
+        disposition to determine whether the error should be retryable.
+    :paramtype retry_policy: ~pyamqp.error.RetryPolicy
+    :keyword keep_alive_interval: If set, a thread will be started to keep the connection
+        alive during periods of user inactivity. The value will determine how long the
+        thread will sleep (in seconds) between pinging the connection. If 0 or None, no
+        thread will be started.
+    :paramtype keep_alive_interval: int
+    :keyword max_frame_size: Maximum AMQP frame size. Default is 63488 bytes.
+    :paramtype max_frame_size: int
+    :keyword channel_max: Maximum number of Session channels in the Connection.
+    :paramtype channel_max: int
+    :keyword idle_timeout: Timeout in seconds after which the Connection will close
+        if there is no further activity.
+    :paramtype idle_timeout: int
+    :keyword auth_timeout: Timeout in seconds for CBS authentication. If CBS
+        authentication is not in use, this value is ignored. Default value is 60s.
+    :paramtype auth_timeout: int
+    :keyword properties: Connection properties.
+    :paramtype properties: dict[str, any]
+    :keyword remote_idle_timeout_empty_frame_send_ratio: Ratio of empty frames to
+        idle time for Connections with no activity. Value must be between
+        0.0 and 1.0 inclusive. Default is 0.5.
+ :paramtype remote_idle_timeout_empty_frame_send_ratio: float + :keyword incoming_window: The size of the allowed window for incoming messages. + :paramtype incoming_window: int + :keyword outgoing_window: The size of the allowed window for outgoing messages. + :paramtype outgoing_window: int + :keyword handle_max: The maximum number of concurrent link handles. + :paramtype handle_max: int + :keyword on_attach: A callback function to be run on receipt of an ATTACH frame. + The function must take 4 arguments: source, target, properties and error. + :paramtype on_attach: func[ + ~pyamqp.endpoint.Source, ~pyamqp.endpoint.Target, dict, ~pyamqp.error.AMQPConnectionError] + :keyword send_settle_mode: The mode by which to settle message send + operations. If set to `Unsettled`, the client will wait for a confirmation + from the service that the message was successfully sent. If set to 'Settled', + the client will not wait for confirmation and assume success. + :paramtype send_settle_mode: ~pyamqp.constants.SenderSettleMode + :keyword receive_settle_mode: The mode by which to settle message receive + operations. If set to `PeekLock`, the receiver will lock a message once received until + the client accepts or rejects the message. If set to `ReceiveAndDelete`, the service + will assume successful receipt of the message and clear it from the queue. The + default is `PeekLock`. + :paramtype receive_settle_mode: ~pyamqp.constants.ReceiverSettleMode + :keyword desired_capabilities: The extension capabilities desired from the peer endpoint. + :paramtype desired_capabilities: list[bytes] + :keyword max_message_size: The maximum allowed message size negotiated for the Link. + :paramtype max_message_size: int + :keyword link_properties: Metadata to be sent in the Link ATTACH frame. + :paramtype link_properties: dict[str, any] + :keyword link_credit: The Link credit that determines how many + messages the Link will attempt to handle per connection iteration. + The default is 300. 
+    :paramtype link_credit: int
+    :keyword transport_type: The type of transport protocol that will be used for communicating with
+        the service. Default is `TransportType.Amqp` in which case port 5671 is used.
+        If port 5671 is unavailable/blocked in the network environment, `TransportType.AmqpOverWebsocket`
+        could be used instead, which uses port 443 for communication.
+    :paramtype transport_type: ~pyamqp.constants.TransportType
+    :keyword http_proxy: HTTP proxy settings. This must be a dictionary with the following
+        keys: `'proxy_hostname'` (str value) and `'proxy_port'` (int value).
+        Additionally the following keys may also be present: `'username', 'password'`.
+    :paramtype http_proxy: dict[str, str]
+    :keyword custom_endpoint_address: The custom endpoint address to use for establishing a connection to
+        the Service Bus service, allowing network requests to be routed through any application gateways or
+        other paths needed for the host environment. Default is None.
+        If port is not specified in the `custom_endpoint_address`, by default port 443 will be used.
+    :paramtype custom_endpoint_address: str
+    :keyword connection_verify: Path to the custom CA_BUNDLE file of the SSL certificate which is used to
+        authenticate the identity of the connection endpoint.
+        Default is None in which case `certifi.where()` will be used.
+    :paramtype connection_verify: str
+    """
+
+    async def _client_ready_async(self):
+        """Determine whether the client is ready to start receiving messages.
+        To be ready, the connection must be open and authentication complete.
+        The Session, Link and MessageReceiver must be open and in non-errored
+        states.
+
+        :rtype: bool
+        """
+        # pylint: disable=protected-access
+        if not self._link:
+            self._link = self._session.create_receiver_link(
+                source_address=self.source,
+                link_credit=self._link_credit,
+                send_settle_mode=self._send_settle_mode,
+                rcv_settle_mode=self._receive_settle_mode,
+                max_message_size=self._max_message_size,
+                on_transfer=self._message_received_async,
+                properties=self._link_properties,
+                desired_capabilities=self._desired_capabilities,
+                on_attach=self._on_attach
+            )
+            await self._link.attach()
+            return False
+        if self._link.get_state().value != 3:  # ATTACHED
+            return False
+        return True
+
+    async def _client_run_async(self, **kwargs):
+        """MessageReceiver Link is now open - start receiving messages.
+        Will return True if the operation succeeded and the client can remain
+        open for further work.
+
+        :rtype: bool
+        """
+        try:
+            if self._link.current_link_credit == 0:
+                await self._link.flow()
+            await self._connection.listen(wait=self._socket_timeout, **kwargs)
+        except ValueError:
+            _logger.info("Timeout reached, closing receiver.", extra=self._network_trace_params)
+            self._shutdown = True
+            return False
+        return True
+
+    async def _message_received_async(self, frame, message):
+        """Callback run on receipt of every message. If there is
+        a user-defined callback, this will be called.
+        Additionally, if the client is retrieving messages for a batch
+        or iterator, the message will be added to an internal queue.
+
+        :param frame: The incoming Transfer frame.
+        :param message: Received message.
+ :type message: ~pyamqp.message.Message + """ + self._last_activity_timestamp = time.time() + if self._message_received_callback: + await self._message_received_callback(message) + if not self._streaming_receive: + self._received_messages.put((frame, message)) + + async def _receive_message_batch_impl_async(self, max_batch_size=None, on_message_received=None, timeout=0): + self._message_received_callback = on_message_received + max_batch_size = max_batch_size or self._link_credit + timeout_time = time.time() + timeout if timeout else 0 + receiving = True + batch = [] + await self.open_async() + while len(batch) < max_batch_size: + try: + # TODO: This drops the transfer frame data + _, message = self._received_messages.get_nowait() + batch.append(message) + self._received_messages.task_done() + except queue.Empty: + break + else: + return batch + + to_receive_size = max_batch_size - len(batch) + before_queue_size = self._received_messages.qsize() + + while receiving and to_receive_size > 0: + now_time = time.time() + if timeout_time and now_time > timeout_time: + break + + try: + receiving = await asyncio.wait_for( + self.do_work_async(batch=to_receive_size), + timeout=timeout_time - now_time if timeout else None + ) + except asyncio.TimeoutError: + break + + cur_queue_size = self._received_messages.qsize() + # after do_work, check how many new messages have been received since previous iteration + received = cur_queue_size - before_queue_size + if to_receive_size < max_batch_size and received == 0: + # there are already messages in the batch, and no message is received in the current cycle + # return what we have + break + + to_receive_size -= received + before_queue_size = cur_queue_size + + while len(batch) < max_batch_size: + try: + _, message = self._received_messages.get_nowait() + batch.append(message) + self._received_messages.task_done() + except queue.Empty: + break + return batch + + async def close_async(self): + self._received_messages = queue.Queue() 
+ await super(ReceiveClientAsync, self).close_async() + + async def receive_message_batch_async(self, **kwargs): + """Receive a batch of messages. Messages returned in the batch have already been + accepted - if you wish to add logic to accept or reject messages based on custom + criteria, pass in a callback. This method will return as soon as some messages are + available rather than waiting to achieve a specific batch size, and therefore the + number of messages returned per call will vary up to the maximum allowed. + + :keyword max_batch_size: The maximum number of messages that can be returned in + one call. This value cannot be larger than the prefetch value, and if not specified, + the prefetch value will be used. + :paramtype max_batch_size: int + :keyword on_message_received: A callback to process messages as they arrive from the + service. It takes a single argument, a ~pyamqp.message.Message object. + :paramtype on_message_received: callable[~pyamqp.message.Message] + :keyword timeout: Timeout in seconds for which to wait to receive any messages. + If no messages are received in this time, an empty list will be returned. If set to + 0, the client will continue to wait until at least one message is received. The + default is 0. + :paramtype timeout: float + """ + return await self._do_retryable_operation_async( + self._receive_message_batch_impl_async, + **kwargs + ) + + async def receive_messages_iter_async(self, timeout=None, on_message_received=None): + """Receive messages by generator. Messages returned in the generator have already been + accepted - if you wish to add logic to accept or reject messages based on custom + criteria, pass in a callback. + + :param on_message_received: A callback to process messages as they arrive from the + service. It takes a single argument, a ~pyamqp.message.Message object. 
+        :type on_message_received: callable[~pyamqp.message.Message]
+        """
+        self._message_received_callback = on_message_received
+        return self._message_generator_async(timeout=timeout)
+
+    async def _message_generator_async(self, timeout=None):
+        """Iterate over processed messages in the receive queue.
+
+        :rtype: generator[~pyamqp.message.Message]
+        """
+        await self.open_async()
+        receiving = True
+        message = None
+        self._last_activity_timestamp = time.time()
+        self._timeout_reached = False
+        self._timeout = timeout if timeout else self._timeout
+        try:
+            while receiving and not self._timeout_reached:
+                if self._timeout > 0:
+                    if time.time() - self._last_activity_timestamp >= self._timeout:
+                        self._timeout_reached = True
+
+                if not self._timeout_reached:
+                    receiving = await self.do_work_async()
+
+                while not self._received_messages.empty():
+                    # The queue stores (frame, message) tuples - yield only the message.
+                    _, message = self._received_messages.get()
+                    self._last_activity_timestamp = time.time()
+                    self._received_messages.task_done()
+                    yield message
+
+        finally:
+            if self._shutdown:
+                await self.close_async()
+
+    @overload
+    async def settle_messages_async(
+        self,
+        delivery_id: Union[int, Tuple[int, int]],
+        outcome: Literal["accepted"],
+        *,
+        batchable: Optional[bool] = None
+    ):
+        ...
+
+    @overload
+    async def settle_messages_async(
+        self,
+        delivery_id: Union[int, Tuple[int, int]],
+        outcome: Literal["released"],
+        *,
+        batchable: Optional[bool] = None
+    ):
+        ...
+
+    @overload
+    async def settle_messages_async(
+        self,
+        delivery_id: Union[int, Tuple[int, int]],
+        outcome: Literal["rejected"],
+        *,
+        error: Optional[AMQPError] = None,
+        batchable: Optional[bool] = None
+    ):
+        ...
+
+    @overload
+    async def settle_messages_async(
+        self,
+        delivery_id: Union[int, Tuple[int, int]],
+        outcome: Literal["modified"],
+        *,
+        delivery_failed: Optional[bool] = None,
+        undeliverable_here: Optional[bool] = None,
+        message_annotations: Optional[Dict[Union[str, bytes], Any]] = None,
+        batchable: Optional[bool] = None
+    ):
+        ...
+ + @overload + async def settle_messages_async( + self, + delivery_id: Union[int, Tuple[int, int]], + outcome: Literal["received"], + *, + section_number: int, + section_offset: int, + batchable: Optional[bool] = None + ): + ... + + async def settle_messages_async(self, delivery_id: Union[int, Tuple[int, int]], outcome: str, **kwargs): + batchable = kwargs.pop('batchable', None) + if outcome.lower() == 'accepted': + state: Outcomes = Accepted() + elif outcome.lower() == 'released': + state = Released() + elif outcome.lower() == 'rejected': + state = Rejected(**kwargs) + elif outcome.lower() == 'modified': + state = Modified(**kwargs) + elif outcome.lower() == 'received': + state = Received(**kwargs) + else: + raise ValueError("Unrecognized message output: {}".format(outcome)) + try: + first, last = cast(Tuple, delivery_id) + except TypeError: + first = delivery_id + last = None + await self._link.send_disposition( + first_delivery_id=first, + last_delivery_id=last, + settled=True, + delivery_state=state, + batchable=batchable, + wait=True + ) diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/aio/_connection_async.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/aio/_connection_async.py new file mode 100644 index 0000000000000..9714cdd9ed85f --- /dev/null +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/aio/_connection_async.py @@ -0,0 +1,874 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. 
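The `settle_messages_async` implementation above dispatches on an outcome string to build the AMQP delivery state, and normalizes the delivery id into a `(first, last)` range for `send_disposition`. A minimal standalone sketch of that dispatch follows; the state classes are simplified stand-ins for pyamqp's outcome types, and `outcome_to_state`/`delivery_range` are hypothetical helper names, not part of this patch:

```python
# Illustrative stand-ins for pyamqp's delivery-state outcome types.
class Accepted:
    pass

class Released:
    pass

class Rejected:
    def __init__(self, error=None):
        self.error = error

class Modified:
    def __init__(self, delivery_failed=None, undeliverable_here=None,
                 message_annotations=None):
        self.delivery_failed = delivery_failed
        self.undeliverable_here = undeliverable_here
        self.message_annotations = message_annotations

def outcome_to_state(outcome, **kwargs):
    """Map an outcome string to a delivery state, mirroring the elif chain."""
    normalized = outcome.lower()
    if normalized == "accepted":
        return Accepted()
    if normalized == "released":
        return Released()
    if normalized == "rejected":
        return Rejected(**kwargs)
    if normalized == "modified":
        return Modified(**kwargs)
    raise ValueError("Unrecognized message outcome: {}".format(outcome))

def delivery_range(delivery_id):
    """Normalize an int delivery id or a (first, last) tuple, as the
    try/except TypeError in settle_messages_async does."""
    try:
        first, last = delivery_id
    except TypeError:
        first, last = delivery_id, None
    return first, last
```

A single id settles one delivery (`last` is None); a tuple settles a contiguous range in one Disposition frame.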
+# --------------------------------------------------------------------------
+
+import uuid
+import logging
+import time
+from urllib.parse import urlparse
+import socket
+from ssl import SSLError
+import asyncio
+from typing import Any, Dict, List, Tuple, Optional, NamedTuple, Union, cast
+
+from ._transport_async import AsyncTransport
+from ._sasl_async import SASLTransport, SASLWithWebSocket
+from ._session_async import Session
+from ..performatives import OpenFrame, CloseFrame
+from .._connection import get_local_timeout, _CLOSING_STATES
+from ..constants import (
+    PORT,
+    SECURE_PORT,
+    WEBSOCKET_PORT,
+    MAX_CHANNELS,
+    MAX_FRAME_SIZE_BYTES,
+    HEADER_FRAME,
+    ConnectionState,
+    EMPTY_FRAME,
+    TransportType,
+    READ_TIMEOUT_INTERVAL
+)
+
+from ..error import ErrorCondition, AMQPConnectionError, AMQPError
+
+_LOGGER = logging.getLogger(__name__)
+
+
+class Connection(object):  # pylint:disable=too-many-instance-attributes
+    """An AMQP Connection.
+
+    :ivar str state: The connection state.
+    :param str endpoint: The endpoint to connect to. Must be fully qualified with scheme and port number.
+    :keyword str container_id: The ID of the source container. If not set, a GUID will be generated.
+    :keyword int max_frame_size: Proposed maximum frame size in bytes. Default value is 64 KB.
+    :keyword int channel_max: The maximum channel number that may be used on the Connection. Default value is 65535.
+    :keyword int idle_timeout: Connection idle time-out in seconds.
+    :keyword list(str) outgoing_locales: Locales available for outgoing text.
+    :keyword list(str) incoming_locales: Desired locales for incoming text in decreasing level of preference.
+    :keyword list(str) offered_capabilities: The extension capabilities the sender supports.
+    :keyword list(str) desired_capabilities: The extension capabilities the sender may use if the receiver
+        supports them.
+    :keyword dict properties: Connection properties.
+ :keyword bool allow_pipelined_open: Allow frames to be sent on the connection before a response Open frame + has been received. Default value is `True`. + :keyword float idle_timeout_empty_frame_send_ratio: Portion of the idle timeout time to wait before sending an + empty frame. The default portion is 50% of the idle timeout value (i.e. `0.5`). + :keyword float idle_wait_time: The time in seconds to sleep while waiting for a response from the endpoint. + Default value is `0.1`. + :keyword bool network_trace: Whether to log the network traffic. Default value is `False`. If enabled, frames + will be logged at the logging.INFO level. + :keyword str transport_type: Determines if the transport type is Amqp or AmqpOverWebSocket. + Defaults to TransportType.Amqp. It will be AmqpOverWebSocket if using http_proxy. + :keyword Dict http_proxy: HTTP proxy settings. This must be a dictionary with the following + keys: `'proxy_hostname'` (str value) and `'proxy_port'` (int value). When using these settings, + the transport_type would be AmqpOverWebSocket. + Additionally the following keys may also be present: `'username', 'password'`. 
+ """ + + def __init__(self, endpoint, **kwargs): # pylint:disable=too-many-statements + # type(str, Any) -> None + parsed_url = urlparse(endpoint) + self._hostname = parsed_url.hostname + endpoint = self._hostname + if parsed_url.port: + self._port = parsed_url.port + elif parsed_url.scheme == "amqps": + self._port = SECURE_PORT + else: + self._port = PORT + self.state = None # type: Optional[ConnectionState] + + # Custom Endpoint + custom_endpoint_address = kwargs.get("custom_endpoint_address") + custom_endpoint = None + if custom_endpoint_address: + custom_parsed_url = urlparse(custom_endpoint_address) + custom_port = custom_parsed_url.port or WEBSOCKET_PORT + custom_endpoint = f"{custom_parsed_url.hostname}:{custom_port}{custom_parsed_url.path}" + self._container_id = kwargs.pop("container_id", None) or str( + uuid.uuid4() + ) # type: str + self._network_trace = kwargs.get("network_trace", False) + self._network_trace_params = {"amqpConnection": self._container_id, "amqpSession": None, "amqpLink": None} + + transport = kwargs.get("transport") + self._transport_type = kwargs.pop("transport_type", TransportType.Amqp) + if transport: + self._transport = transport + elif "sasl_credential" in kwargs: + sasl_transport = SASLTransport + if self._transport_type.name == "AmqpOverWebsocket" or kwargs.get( + "http_proxy" + ): + sasl_transport = SASLWithWebSocket + endpoint = parsed_url.hostname + parsed_url.path + self._transport = sasl_transport( + host=endpoint, + credential=kwargs["sasl_credential"], + custom_endpoint=custom_endpoint, + network_trace_params=self._network_trace_params, + **kwargs, + ) + else: + self._transport = AsyncTransport( + parsed_url.netloc, + network_trace_params=self._network_trace_params, + **kwargs) + + self._max_frame_size = kwargs.pop( + "max_frame_size", MAX_FRAME_SIZE_BYTES + ) # type: int + self._remote_max_frame_size = None # type: Optional[int] + self._channel_max = kwargs.pop("channel_max", MAX_CHANNELS) # type: int + 
self._idle_timeout = kwargs.pop("idle_timeout", None) # type: Optional[int] + self._outgoing_locales = kwargs.pop( + "outgoing_locales", None + ) # type: Optional[List[str]] + self._incoming_locales = kwargs.pop( + "incoming_locales", None + ) # type: Optional[List[str]] + self._offered_capabilities = None # type: Optional[str] + self._desired_capabilities = kwargs.pop( + "desired_capabilities", None + ) # type: Optional[str] + self._properties = kwargs.pop( + "properties", None + ) # type: Optional[Dict[str, str]] + + self._remote_properties: Optional[Dict[str, str]] = None + + self._allow_pipelined_open = kwargs.pop( + "allow_pipelined_open", True + ) # type: bool + self._remote_idle_timeout = None # type: Optional[int] + self._remote_idle_timeout_send_frame = None # type: Optional[int] + self._idle_timeout_empty_frame_send_ratio = kwargs.get( + "idle_timeout_empty_frame_send_ratio", 0.5 + ) + self._last_frame_received_time = None # type: Optional[float] + self._last_frame_sent_time = None # type: Optional[float] + self._idle_wait_time = kwargs.get("idle_wait_time", 0.1) # type: float + self._error = None + self._outgoing_endpoints = {} # type: Dict[int, Session] + self._incoming_endpoints = {} # type: Dict[int, Session] + + async def __aenter__(self): + await self.open() + return self + + async def __aexit__(self, *args): + await self.close() + + async def _set_state(self, new_state): + # type: (ConnectionState) -> None + """Update the connection state.""" + if new_state is None: + return + previous_state = self.state + self.state = new_state + _LOGGER.info( + "Connection state changed: %r -> %r", + previous_state, + new_state, + extra=self._network_trace_params + ) + for session in self._outgoing_endpoints.values(): + await session._on_connection_state_change() # pylint:disable=protected-access + + async def _connect(self): + # type: () -> None + """Initiate the connection. 
+ + If `allow_pipelined_open` is enabled, the incoming response header will be processed immediately + and the state on exiting will be HDR_EXCH. Otherwise, the function will return before waiting for + the response header and the final state will be HDR_SENT. + + :raises ValueError: If a reciprocating protocol header is not received during negotiation. + """ + try: + if not self.state: + await self._transport.connect() + await self._set_state(ConnectionState.START) + await self._transport.negotiate() + await self._outgoing_header() + await self._set_state(ConnectionState.HDR_SENT) + if not self._allow_pipelined_open: + await self._read_frame(wait=True) + if self.state != ConnectionState.HDR_EXCH: + await self._disconnect() + raise ValueError( + "Did not receive reciprocal protocol header. Disconnecting." + ) + else: + await self._set_state(ConnectionState.HDR_SENT) + except (OSError, IOError, SSLError, socket.error, asyncio.TimeoutError) as exc: + # FileNotFoundError is being raised for exception parity with uamqp when invalid + # `connection_verify` file path is passed in. Remove later when resolving issue #27128. + if isinstance(exc, FileNotFoundError) and exc.filename and "ca_certs" in exc.filename: + raise + raise AMQPConnectionError( + ErrorCondition.SocketError, + description="Failed to initiate the connection due to exception: " + + str(exc), + error=exc, + ) + + async def _disconnect(self) -> None: + """Disconnect the transport and set state to END.""" + if self.state == ConnectionState.END: + return + await self._set_state(ConnectionState.END) + await self._transport.close() + + def _can_read(self): + # type: () -> bool + """Whether the connection is in a state where it is legal to read for incoming frames.""" + return self.state not in (ConnectionState.CLOSE_RCVD, ConnectionState.END) + + async def _read_frame(self, wait: Union[bool, int, float] = True, **kwargs) -> bool: + """Read an incoming frame from the transport. 
+
+        :param Union[bool, float] wait: Whether to block on the socket while waiting for an incoming frame.
+            The default value is `True`, in which case the socket will block indefinitely. If set to `False`,
+            the read will only block for the configured timeout interval (0.1 seconds). If set to a timeout
+            value in seconds, the socket will block for at most that value.
+        :rtype: bool
+        :returns: The result of processing the incoming frame: `True` if the frame altered the connection
+            state and any batched frame processing should stop.
+        """
+        timeout: Optional[Union[int, float]] = None
+        if wait is False:
+            timeout = READ_TIMEOUT_INTERVAL
+        elif wait is True:
+            timeout = None
+        else:
+            timeout = wait
+        new_frame = await self._transport.receive_frame(timeout=timeout, **kwargs)
+        return await self._process_incoming_frame(*new_frame)
+
+    def _can_write(self):
+        # type: () -> bool
+        """Whether the connection is in a state where it is legal to write outgoing frames."""
+        return self.state not in _CLOSING_STATES
+
+    async def _send_frame(self, channel, frame, timeout=None, **kwargs):
+        # type: (int, NamedTuple, Optional[int], Any) -> None
+        """Send a frame over the connection.
+
+        :param int channel: The outgoing channel number.
+        :param NamedTuple frame: The outgoing frame.
+        :param int timeout: An optional timeout value to wait until the socket is ready to send the frame.
+ :rtype: None + """ + try: + raise self._error + except TypeError: + pass + + if self._can_write(): + try: + self._last_frame_sent_time = time.time() + await asyncio.wait_for( + self._transport.send_frame(channel, frame, **kwargs), + timeout=timeout, + ) + except ( + OSError, + IOError, + SSLError, + socket.error, + asyncio.TimeoutError, + ) as exc: + self._error = AMQPConnectionError( + ErrorCondition.SocketError, + description="Can not send frame out due to exception: " + str(exc), + error=exc, + ) + else: + _LOGGER.info("Cannot write frame in current state: %r", self.state, extra=self._network_trace_params) + + def _get_next_outgoing_channel(self): + # type: () -> int + """Get the next available outgoing channel number within the max channel limit. + + :raises ValueError: If maximum channels has been reached. + :returns: The next available outgoing channel number. + :rtype: int + """ + if ( + len(self._incoming_endpoints) + len(self._outgoing_endpoints) + ) >= self._channel_max: + raise ValueError( + "Maximum number of channels ({}) has been reached.".format( + self._channel_max + ) + ) + next_channel = next( + i for i in range(1, self._channel_max) if i not in self._outgoing_endpoints + ) + return next_channel + + async def _outgoing_empty(self): + # type: () -> None + """Send an empty frame to prevent the connection from reaching an idle timeout.""" + if self._network_trace: + _LOGGER.debug("-> EmptyFrame()", extra=self._network_trace_params) + try: + raise self._error + except TypeError: + pass + try: + if self._can_write(): + await self._transport.write(EMPTY_FRAME) + self._last_frame_sent_time = time.time() + except (OSError, IOError, SSLError, socket.error) as exc: + self._error = AMQPConnectionError( + ErrorCondition.SocketError, + description="Can not send empty frame due to exception: " + str(exc), + error=exc, + ) + + async def _outgoing_header(self): + # type: () -> None + """Send the AMQP protocol header to initiate the connection.""" + 
self._last_frame_sent_time = time.time() + if self._network_trace: + _LOGGER.debug("-> Header(%r)", HEADER_FRAME, extra=self._network_trace_params) + await self._transport.write(HEADER_FRAME) + + async def _incoming_header(self, _, frame): + # type: (int, bytes) -> None + """Process an incoming AMQP protocol header and update the connection state.""" + if self._network_trace: + _LOGGER.debug("<- Header(%r)", frame, extra=self._network_trace_params) + if self.state == ConnectionState.START: + await self._set_state(ConnectionState.HDR_RCVD) + elif self.state == ConnectionState.HDR_SENT: + await self._set_state(ConnectionState.HDR_EXCH) + elif self.state == ConnectionState.OPEN_PIPE: + await self._set_state(ConnectionState.OPEN_SENT) + + async def _outgoing_open(self): + # type: () -> None + """Send an Open frame to negotiate the AMQP connection functionality.""" + open_frame = OpenFrame( + container_id=self._container_id, + hostname=self._hostname, + max_frame_size=self._max_frame_size, + channel_max=self._channel_max, + idle_timeout=self._idle_timeout * 1000 + if self._idle_timeout + else None, # Convert to milliseconds + outgoing_locales=self._outgoing_locales, + incoming_locales=self._incoming_locales, + offered_capabilities=self._offered_capabilities + if self.state == ConnectionState.OPEN_RCVD + else None, + desired_capabilities=self._desired_capabilities + if self.state == ConnectionState.HDR_EXCH + else None, + properties=self._properties, + ) + if self._network_trace: + _LOGGER.debug("-> %r", open_frame, extra=self._network_trace_params) + await self._send_frame(0, open_frame) + + async def _incoming_open(self, channel, frame): + # type: (int, Tuple[Any, ...]) -> None + """Process incoming Open frame to finish the connection negotiation. 
+ + The incoming frame format is:: + + - frame[0]: container_id (str) + - frame[1]: hostname (str) + - frame[2]: max_frame_size (int) + - frame[3]: channel_max (int) + - frame[4]: idle_timeout (Optional[int]) + - frame[5]: outgoing_locales (Optional[List[bytes]]) + - frame[6]: incoming_locales (Optional[List[bytes]]) + - frame[7]: offered_capabilities (Optional[List[bytes]]) + - frame[8]: desired_capabilities (Optional[List[bytes]]) + - frame[9]: properties (Optional[Dict[bytes, bytes]]) + + :param int channel: The incoming channel number. + :param frame: The incoming Open frame. + :type frame: Tuple[Any, ...] + :rtype: None + """ + # TODO: Add type hints for full frame tuple contents. + if self._network_trace: + _LOGGER.debug("<- %r", OpenFrame(*frame), extra=self._network_trace_params) + if channel != 0: + _LOGGER.error("OPEN frame received on a channel that is not 0.", extra=self._network_trace_params) + await self.close( + error=AMQPError( + condition=ErrorCondition.NotAllowed, + description="OPEN frame received on a channel that is not 0.", + ) + ) + await self._set_state(ConnectionState.END) + if self.state == ConnectionState.OPENED: + _LOGGER.error("OPEN frame received in the OPENED state.", extra=self._network_trace_params) + await self.close() + if frame[4]: + self._remote_idle_timeout = frame[4] / 1000 # Convert to seconds + self._remote_idle_timeout_send_frame = ( + self._idle_timeout_empty_frame_send_ratio * self._remote_idle_timeout + ) + + if frame[2] < 512: + # Max frame size is less than supported minimum + # If any of the values in the received open frame are invalid then the connection shall be closed. + # The error amqp:invalid-field shall be set in the error.condition field of the CLOSE frame. 
+            await self.close(
+                error=AMQPError(
+                    condition=ErrorCondition.InvalidField,
+                    description="Failed parsing OPEN frame: Max frame size is less than supported minimum.",
+                )
+            )
+            _LOGGER.error(
+                "Failed parsing OPEN frame: Max frame size is less than supported minimum.",
+                extra=self._network_trace_params
+            )
+            return
+        self._remote_max_frame_size = frame[2]
+        self._remote_properties = frame[9]
+        if self.state == ConnectionState.OPEN_SENT:
+            await self._set_state(ConnectionState.OPENED)
+        elif self.state == ConnectionState.HDR_EXCH:
+            await self._set_state(ConnectionState.OPEN_RCVD)
+            await self._outgoing_open()
+            await self._set_state(ConnectionState.OPENED)
+        else:
+            await self.close(
+                error=AMQPError(
+                    condition=ErrorCondition.IllegalState,
+                    description=f"Connection is in an illegal state: {self.state}",
+                )
+            )
+            _LOGGER.error("Connection is in an illegal state: %r", self.state, extra=self._network_trace_params)
+
+    async def _outgoing_close(self, error=None):
+        # type: (Optional[AMQPError]) -> None
+        """Send a Close frame to shut down the connection, with optional error information."""
+        close_frame = CloseFrame(error=error)
+        if self._network_trace:
+            _LOGGER.debug("-> %r", close_frame, extra=self._network_trace_params)
+        await self._send_frame(0, close_frame)
+
+    async def _incoming_close(self, channel, frame):
+        # type: (int, Tuple[Any, ...]) -> None
+        """Process an incoming Close frame to shut down the connection.
+
+        The incoming frame format is::
+
+            - frame[0]: error (Optional[AMQPError])
+
+        """
+        if self._network_trace:
+            _LOGGER.debug("<- %r", CloseFrame(*frame), extra=self._network_trace_params)
+        disconnect_states = [
+            ConnectionState.HDR_RCVD,
+            ConnectionState.HDR_EXCH,
+            ConnectionState.OPEN_RCVD,
+            ConnectionState.CLOSE_SENT,
+            ConnectionState.DISCARDING,
+        ]
+        if self.state in disconnect_states:
+            await self._disconnect()
+            return
+
+        close_error = None
+        if channel > self._channel_max:
+            _LOGGER.error(
+                "CLOSE frame received on a channel greater than the supported max.",
+                extra=self._network_trace_params
+            )
+            close_error = AMQPError(
+                condition=ErrorCondition.InvalidField,
+                description="Invalid channel",
+                info=None,
+            )
+
+        await self._set_state(ConnectionState.CLOSE_RCVD)
+        await self._outgoing_close(error=close_error)
+        await self._disconnect()
+
+        if frame[0]:
+            self._error = AMQPConnectionError(
+                condition=frame[0][0], description=frame[0][1], info=frame[0][2]
+            )
+            _LOGGER.error(
+                "Connection closed with error: %r", frame[0],
+                extra=self._network_trace_params
+            )
+
+    async def _incoming_begin(self, channel, frame):
+        # type: (int, Tuple[Any, ...]) -> None
+        """Process incoming Begin frame to finish negotiating a new session.
+
+        The incoming frame format is::
+
+            - frame[0]: remote_channel (int)
+            - frame[1]: next_outgoing_id (int)
+            - frame[2]: incoming_window (int)
+            - frame[3]: outgoing_window (int)
+            - frame[4]: handle_max (int)
+            - frame[5]: offered_capabilities (Optional[List[bytes]])
+            - frame[6]: desired_capabilities (Optional[List[bytes]])
+            - frame[7]: properties (Optional[Dict[bytes, bytes]])
+
+        :param int channel: The incoming channel number.
+        :param frame: The incoming Begin frame.
+        :type frame: Tuple[Any, ...]
+        :rtype: None
+        """
+        try:
+            existing_session = self._outgoing_endpoints[frame[0]]
+            self._incoming_endpoints[channel] = existing_session
+            await self._incoming_endpoints[channel]._incoming_begin(  # pylint:disable=protected-access
+                frame
+            )
+        except KeyError:
+            new_session = Session.from_incoming_frame(self, channel)
+            self._incoming_endpoints[channel] = new_session
+            await new_session._incoming_begin(frame)  # pylint:disable=protected-access
+
+    async def _incoming_end(self, channel, frame):
+        # type: (int, Tuple[Any, ...]) -> None
+        """Process incoming End frame to close a session.
+
+        The incoming frame format is::
+
+            - frame[0]: error (Optional[AMQPError])
+
+        :param int channel: The incoming channel number.
+        :param frame: The incoming End frame.
+        :type frame: Tuple[Any, ...]
+        :rtype: None
+        """
+        try:
+            await self._incoming_endpoints[channel]._incoming_end(frame)  # pylint:disable=protected-access
+            self._incoming_endpoints.pop(channel)
+            self._outgoing_endpoints.pop(channel)
+        except KeyError:
+            # Close the connection.
+            await self.close(
+                error=AMQPError(
+                    condition=ErrorCondition.ConnectionCloseForced,
+                    description="Invalid channel number received"
+                ))
+            _LOGGER.error(
+                "END frame received on invalid channel. Closing connection.",
+                extra=self._network_trace_params
+            )
+            return
+
+    async def _process_incoming_frame(
+        self, channel, frame
+    ):  # pylint:disable=too-many-return-statements
+        # type: (int, Optional[Union[bytes, Tuple[int, Tuple[Any, ...]]]]) -> bool
+        """Process an incoming frame, either directly or by passing to the necessary Session.
+
+        :param int channel: The channel the frame arrived on.
+        :param frame: A tuple containing the performative descriptor and the field values of the frame.
+         This parameter can be None in the case of an empty frame or a socket timeout.
+ :type frame: Optional[Tuple[int, NamedTuple]] + :rtype: bool + :returns: A boolean to indicate whether more frames in a batch can be processed or whether the + incoming frame has altered the state. If `True` is returned, the state has changed and the batch + should be interrupted. + """ + try: + performative, fields = cast(Union[bytes, Tuple], frame) + except TypeError: + return True # Empty Frame or socket timeout + fields = cast(Tuple[Any, ...], fields) + try: + self._last_frame_received_time = time.time() + if performative == 20: + await self._incoming_endpoints[channel]._incoming_transfer( # pylint:disable=protected-access + fields + ) + return False + if performative == 21: + await self._incoming_endpoints[channel]._incoming_disposition( # pylint:disable=protected-access + fields + ) + return False + if performative == 19: + await self._incoming_endpoints[channel]._incoming_flow( # pylint:disable=protected-access + fields + ) + return False + if performative == 18: + await self._incoming_endpoints[channel]._incoming_attach( # pylint:disable=protected-access + fields + ) + return False + if performative == 22: + await self._incoming_endpoints[channel]._incoming_detach( # pylint:disable=protected-access + fields + ) + return True + if performative == 17: + await self._incoming_begin(channel, fields) + return True + if performative == 23: + await self._incoming_end(channel, fields) + return True + if performative == 16: + await self._incoming_open(channel, fields) + return True + if performative == 24: + await self._incoming_close(channel, fields) + return True + if performative == 0: + await self._incoming_header(channel, cast(bytes, fields)) + return True + if performative == 1: + return False # TODO: incoming EMPTY + _LOGGER.error("Unrecognized incoming frame: %r", frame, extra=self._network_trace_params) + return True + except KeyError: + return True # TODO: channel error + + async def _process_outgoing_frame(self, channel, frame): + # type: (int, 
NamedTuple) -> None
+        """Send an outgoing frame if the connection is in a legal state.
+
+        :raises ValueError: If the connection is not open or not in a valid state.
+        """
+        if not self._allow_pipelined_open and self.state in [
+            ConnectionState.OPEN_PIPE,
+            ConnectionState.OPEN_SENT,
+        ]:
+            raise ValueError("Connection not configured to allow pipelined send.")
+        if self.state not in [
+            ConnectionState.OPEN_PIPE,
+            ConnectionState.OPEN_SENT,
+            ConnectionState.OPENED,
+        ]:
+            raise ValueError("Connection not open.")
+        now = time.time()
+        if get_local_timeout(
+            now,
+            cast(float, self._idle_timeout),
+            cast(float, self._last_frame_received_time),
+        ) or (await self._get_remote_timeout(now)):
+            _LOGGER.info(
+                "No frame received for the idle timeout. Closing connection.",
+                extra=self._network_trace_params
+            )
+            await self.close(
+                error=AMQPError(
+                    condition=ErrorCondition.ConnectionCloseForced,
+                    description="No frame received for the idle timeout.",
+                ),
+                wait=False,
+            )
+            return
+        await self._send_frame(channel, frame)
+
+    async def _get_remote_timeout(self, now):
+        # type: (float) -> bool
+        """Check whether the local connection has reached the remote endpoint's idle timeout since
+        the last outgoing frame was sent.
+
+        If the time since the last sent frame is greater than the allowed idle interval, an Empty
+        frame will be sent to maintain the connection.
+
+        :param float now: The current time to check against.
+        :rtype: bool
+        :returns: Whether the local connection should be shut down due to timeout.
+        """
+        if self._remote_idle_timeout and self._last_frame_sent_time:
+            time_since_last_sent = now - self._last_frame_sent_time
+            if time_since_last_sent > cast(int, self._remote_idle_timeout_send_frame):
+                await self._outgoing_empty()
+        return False
+
+    async def _wait_for_response(self, wait, end_state):
+        # type: (Union[bool, float], ConnectionState) -> None
+        """Wait for an incoming frame to be processed that will result in a desired state change.
+
+        :param wait: Whether to wait for an incoming frame to be processed. Can be set to `True` to wait
+         indefinitely, or a numeric value to wait for a specified amount of time (in seconds). To not wait,
+         set to `False`.
+        :type wait: bool or float
+        :param ConnectionState end_state: The desired end state to wait until.
+        :rtype: None
+        """
+        if wait is True:
+            await self.listen(wait=False)
+            while self.state != end_state:
+                await asyncio.sleep(self._idle_wait_time)
+                await self.listen(wait=False)
+        elif wait:
+            await self.listen(wait=False)
+            timeout = time.time() + wait
+            while self.state != end_state:
+                if time.time() >= timeout:
+                    break
+                await asyncio.sleep(self._idle_wait_time)
+                await self.listen(wait=False)
+
+    async def listen(self, wait=False, batch=1, **kwargs):
+        # type: (Union[float, int, bool], int, Any) -> None
+        """Listen on the socket for incoming frames and process them.
+
+        :param wait: Whether to block on the socket until a frame arrives. If set to `True`, the socket will
+         block indefinitely. Alternatively, if set to a time in seconds, the socket will block for at most
+         the specified timeout. Default value is `False`, where the socket will block for its configured read
+         timeout (by default 0.1 seconds).
+        :type wait: int or float or bool
+        :param int batch: The number of frames to attempt to read and process before returning. The default value
+         is 1, i.e. process frames one-at-a-time. A higher value should only be used when a receiver is established
+         and is processing incoming Transfer frames.
+        :rtype: None
+        """
+        try:
+            raise self._error
+        except TypeError:
+            pass
+        try:
+            if self.state not in _CLOSING_STATES:
+                now = time.time()
+                if get_local_timeout(
+                    now,
+                    cast(float, self._idle_timeout),
+                    cast(float, self._last_frame_received_time),
+                ) or (await self._get_remote_timeout(now)):
+                    _LOGGER.info(
+                        "No frame received for the idle timeout. Closing connection.",
+                        extra=self._network_trace_params
+                    )
+                    await self.close(
+                        error=AMQPError(
+                            condition=ErrorCondition.ConnectionCloseForced,
+                            description="No frame received for the idle timeout.",
+                        ),
+                        wait=False,
+                    )
+                    return
+            if self.state == ConnectionState.END:
+                # TODO: check error condition
+                self._error = AMQPConnectionError(
+                    condition=ErrorCondition.ConnectionCloseForced,
+                    description="Connection was already closed.",
+                )
+                return
+            for _ in range(batch):
+                if self._can_read():
+                    if await self._read_frame(wait=wait, **kwargs):
+                        break
+                else:
+                    _LOGGER.info(
+                        "Connection cannot read frames in this state: %r",
+                        self.state,
+                        extra=self._network_trace_params
+                    )
+                    break
+        except (OSError, IOError, SSLError, socket.error) as exc:
+            self._error = AMQPConnectionError(
+                ErrorCondition.SocketError,
+                description="Cannot read frame due to exception: " + str(exc),
+                error=exc,
+            )
+
+    def create_session(self, **kwargs):
+        # type: (Any) -> Session
+        """Create a new session within this connection.
+
+        :keyword str name: The name of the session. If not set a GUID will be generated.
+        :keyword int next_outgoing_id: The transfer-id of the first transfer frame the sender will send.
+         Default value is 0.
+        :keyword int incoming_window: The initial incoming-window of the Session. Default value is 1.
+        :keyword int outgoing_window: The initial outgoing-window of the Session. Default value is 1.
+        :keyword int handle_max: The maximum handle value that may be used on the session. Default value is 4294967295.
+        :keyword list(str) offered_capabilities: The extension capabilities the session supports.
+        :keyword list(str) desired_capabilities: The extension capabilities the session may use if
+         the endpoint supports it.
+        :keyword dict properties: Session properties.
+        :keyword bool allow_pipelined_open: Allow frames to be sent on the connection before a response Open frame
+         has been received. Default value is that configured for the connection.
+        :keyword float idle_wait_time: The time in seconds to sleep while waiting for a response from the endpoint.
+         Default value is that configured for the connection.
+        :keyword bool network_trace: Whether to log the network traffic of this session. If enabled, frames
+         will be logged at the logging.INFO level. Default value is that configured for the connection.
+        """
+        assigned_channel = self._get_next_outgoing_channel()
+        kwargs["allow_pipelined_open"] = self._allow_pipelined_open
+        kwargs["idle_wait_time"] = self._idle_wait_time
+        session = Session(
+            self,
+            assigned_channel,
+            network_trace=kwargs.pop("network_trace", self._network_trace),
+            network_trace_params=dict(self._network_trace_params),
+            **kwargs,
+        )
+        self._outgoing_endpoints[assigned_channel] = session
+        return session
+
+    async def open(self, wait=False):
+        # type: (bool) -> None
+        """Send an Open frame to start the connection.
+
+        Alternatively, this will be called on entering a Connection context manager.
+
+        :param bool wait: Whether to wait to receive an Open response from the endpoint. Default is `False`.
+        :raises ValueError: If `wait` is set to `False` and `allow_pipelined_open` is disabled.
+        :rtype: None
+        """
+        await self._connect()
+        await self._outgoing_open()
+        if self.state == ConnectionState.HDR_EXCH:
+            await self._set_state(ConnectionState.OPEN_SENT)
+        elif self.state == ConnectionState.HDR_SENT:
+            await self._set_state(ConnectionState.OPEN_PIPE)
+        if wait:
+            await self._wait_for_response(wait, ConnectionState.OPENED)
+        elif not self._allow_pipelined_open:
+            raise ValueError(
+                "Connection has been configured to not allow pipelined-open. Please set 'wait' parameter."
+            )
+
+    async def close(self, error=None, wait=False):
+        # type: (Optional[AMQPError], bool) -> None
+        """Close the connection and disconnect the transport.
+
+        Alternatively this method will be called on exiting a Connection context manager.
+ + :param ~uamqp.AMQPError error: Optional error information to include in the close request. + :param bool wait: Whether to wait for a service Close response. Default is `False`. + :rtype: None + """ + try: + if self.state in [ + ConnectionState.END, + ConnectionState.CLOSE_SENT, + ConnectionState.DISCARDING, + ]: + return + await self._outgoing_close(error=error) + if error: + self._error = AMQPConnectionError( + condition=error.condition, + description=error.description, + info=error.info, + ) + if self.state == ConnectionState.OPEN_PIPE: + await self._set_state(ConnectionState.OC_PIPE) + elif self.state == ConnectionState.OPEN_SENT: + await self._set_state(ConnectionState.CLOSE_PIPE) + elif error: + await self._set_state(ConnectionState.DISCARDING) + else: + await self._set_state(ConnectionState.CLOSE_SENT) + await self._wait_for_response(wait, ConnectionState.END) + except Exception as exc: # pylint:disable=broad-except + # If error happened during closing, ignore the error and set state to END + _LOGGER.info("An error occurred when closing the connection: %r", exc, extra=self._network_trace_params) + await self._set_state(ConnectionState.END) + finally: + await self._disconnect() diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/aio/_link_async.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/aio/_link_async.py new file mode 100644 index 0000000000000..96f647b49858d --- /dev/null +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/aio/_link_async.py @@ -0,0 +1,260 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. 
+# -------------------------------------------------------------------------- + +from typing import Optional +import uuid +import logging + +from ..endpoints import Source, Target +from ..constants import DEFAULT_LINK_CREDIT, SessionState, LinkState, Role, SenderSettleMode, ReceiverSettleMode +from ..performatives import ( + AttachFrame, + DetachFrame, +) + +from ..error import ErrorCondition, AMQPLinkError, AMQPLinkRedirect, AMQPConnectionError + +_LOGGER = logging.getLogger(__name__) + + +class Link(object): # pylint: disable=too-many-instance-attributes + """An AMQP Link. + + This object should not be used directly - instead use one of directional + derivatives: Sender or Receiver. + """ + + def __init__(self, session, handle, name, role, **kwargs): + self.state = LinkState.DETACHED + self.name = name or str(uuid.uuid4()) + self.handle = handle + self.remote_handle = None + self.role = role + source_address = kwargs["source_address"] + target_address = kwargs["target_address"] + self.source = ( + source_address + if isinstance(source_address, Source) + else Source( + address=kwargs["source_address"], + durable=kwargs.get("source_durable"), + expiry_policy=kwargs.get("source_expiry_policy"), + timeout=kwargs.get("source_timeout"), + dynamic=kwargs.get("source_dynamic"), + dynamic_node_properties=kwargs.get("source_dynamic_node_properties"), + distribution_mode=kwargs.get("source_distribution_mode"), + filters=kwargs.get("source_filters"), + default_outcome=kwargs.get("source_default_outcome"), + outcomes=kwargs.get("source_outcomes"), + capabilities=kwargs.get("source_capabilities"), + ) + ) + self.target = ( + target_address + if isinstance(target_address, Target) + else Target( + address=kwargs["target_address"], + durable=kwargs.get("target_durable"), + expiry_policy=kwargs.get("target_expiry_policy"), + timeout=kwargs.get("target_timeout"), + dynamic=kwargs.get("target_dynamic"), + dynamic_node_properties=kwargs.get("target_dynamic_node_properties"), + 
capabilities=kwargs.get("target_capabilities"), + ) + ) + self.link_credit = kwargs.pop("link_credit", None) or DEFAULT_LINK_CREDIT + self.current_link_credit = self.link_credit + self.send_settle_mode = kwargs.pop("send_settle_mode", SenderSettleMode.Mixed) + self.rcv_settle_mode = kwargs.pop("rcv_settle_mode", ReceiverSettleMode.First) + self.unsettled = kwargs.pop("unsettled", None) + self.incomplete_unsettled = kwargs.pop("incomplete_unsettled", None) + self.initial_delivery_count = kwargs.pop("initial_delivery_count", 0) + self.delivery_count = self.initial_delivery_count + self.received_delivery_id = None + self.max_message_size = kwargs.pop("max_message_size", None) + self.remote_max_message_size = None + self.available = kwargs.pop("available", None) + self.properties = kwargs.pop("properties", None) + self.remote_properties = None + self.offered_capabilities = None + self.desired_capabilities = kwargs.pop("desired_capabilities", None) + + self.network_trace = kwargs["network_trace"] + self.network_trace_params = kwargs["network_trace_params"] + self.network_trace_params["amqpLink"] = self.name + self._session = session + self._is_closed = False + self._on_link_state_change = kwargs.get("on_link_state_change") + self._on_attach = kwargs.get("on_attach") + self._error = None + + async def __aenter__(self): + await self.attach() + return self + + async def __aexit__(self, *args): + await self.detach(close=True) + + @classmethod + def from_incoming_frame(cls, session, handle, frame): + # check link_create_from_endpoint in C lib + raise NotImplementedError("Pending") # TODO: Assuming we establish all links for now... 
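For orientation, the attach handshake this class implements (see `_incoming_attach` and `_set_state`) reduces to a small state machine: an incoming ATTACH either acknowledges our own outgoing ATTACH or opens the handshake from the peer's side. A stand-alone sketch of that transition, using a simplified local `LinkState` stand-in rather than the library's constants module (illustrative only, not the library code):

```python
# Simplified, self-contained model of the link attach handshake.
# LinkState here is a local stand-in for the values in ..constants.
from enum import Enum


class LinkState(Enum):
    DETACHED = 0
    ATTACH_SENT = 1
    ATTACH_RCVD = 2
    ATTACHED = 3


def on_incoming_attach(state):
    """Mirror of the state change in Link._incoming_attach: an incoming
    ATTACH either completes our half-open handshake, or starts one
    initiated by the peer."""
    if state is LinkState.ATTACH_SENT:
        return LinkState.ATTACHED
    if state is LinkState.DETACHED:
        return LinkState.ATTACH_RCVD
    return state
```

Either endpoint may send ATTACH first; the same transition function covers both the locally-initiated and peer-initiated cases.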
+
+    def get_state(self):
+        try:
+            raise self._error
+        except TypeError:
+            pass
+        return self.state
+
+    def _check_if_closed(self):
+        if self._is_closed:
+            try:
+                raise self._error
+            except TypeError:
+                raise AMQPConnectionError(condition=ErrorCondition.InternalError, description="Link already closed.")
+
+    async def _set_state(self, new_state):
+        # type: (LinkState) -> None
+        """Update the link state."""
+        if new_state is None:
+            return
+        previous_state = self.state
+        self.state = new_state
+        _LOGGER.info("Link state changed: %r -> %r", previous_state, new_state, extra=self.network_trace_params)
+        try:
+            await self._on_link_state_change(previous_state, new_state)
+        except TypeError:
+            pass
+        except Exception as e:  # pylint: disable=broad-except
+            _LOGGER.error("Link state change callback failed: '%r'", e, extra=self.network_trace_params)
+
+    async def _on_session_state_change(self):
+        if self._session.state == SessionState.MAPPED:
+            if not self._is_closed and self.state == LinkState.DETACHED:
+                await self._outgoing_attach()
+                await self._set_state(LinkState.ATTACH_SENT)
+        elif self._session.state == SessionState.DISCARDING:
+            await self._set_state(LinkState.DETACHED)
+
+    async def _outgoing_attach(self):
+        self.delivery_count = self.initial_delivery_count
+        attach_frame = AttachFrame(
+            name=self.name,
+            handle=self.handle,
+            role=self.role,
+            send_settle_mode=self.send_settle_mode,
+            rcv_settle_mode=self.rcv_settle_mode,
+            source=self.source,
+            target=self.target,
+            unsettled=self.unsettled,
+            incomplete_unsettled=self.incomplete_unsettled,
+            initial_delivery_count=self.initial_delivery_count if self.role == Role.Sender else None,
+            max_message_size=self.max_message_size,
+            offered_capabilities=self.offered_capabilities if self.state == LinkState.ATTACH_RCVD else None,
+            desired_capabilities=self.desired_capabilities if self.state == LinkState.DETACHED else None,
+            properties=self.properties,
+        )
+        if self.network_trace:
+            _LOGGER.debug("-> %r", attach_frame,
extra=self.network_trace_params) + await self._session._outgoing_attach(attach_frame) # pylint: disable=protected-access + + async def _incoming_attach(self, frame): + if self.network_trace: + _LOGGER.debug("<- %r", AttachFrame(*frame), extra=self.network_trace_params) + if self._is_closed: + raise ValueError("Invalid link") + if not frame[5] or not frame[6]: + _LOGGER.info("Cannot get source or target. Detaching link", extra=self.network_trace_params) + await self._set_state(LinkState.DETACHED) + raise ValueError("Invalid link") + self.remote_handle = frame[1] # handle + self.remote_max_message_size = frame[10] # max_message_size + self.offered_capabilities = frame[11] # offered_capabilities + self.remote_properties = frame[13] + if self.state == LinkState.DETACHED: + await self._set_state(LinkState.ATTACH_RCVD) + elif self.state == LinkState.ATTACH_SENT: + await self._set_state(LinkState.ATTACHED) + if self._on_attach: + try: + if frame[5]: + frame[5] = Source(*frame[5]) + if frame[6]: + frame[6] = Target(*frame[6]) + await self._on_attach(AttachFrame(*frame)) + except Exception as e: # pylint: disable=broad-except + _LOGGER.warning("Callback for link attach raised error: %s", e, extra=self.network_trace_params) + + async def _outgoing_flow(self, **kwargs): + flow_frame = { + "handle": self.handle, + "delivery_count": self.delivery_count, + "link_credit": self.current_link_credit, + "available": kwargs.get("available"), + "drain": kwargs.get("drain"), + "echo": kwargs.get("echo"), + "properties": kwargs.get("properties"), + } + await self._session._outgoing_flow(flow_frame) # pylint: disable=protected-access + + async def _incoming_flow(self, frame): + pass + + async def _incoming_disposition(self, frame): + pass + + async def _outgoing_detach(self, close=False, error=None): + detach_frame = DetachFrame(handle=self.handle, closed=close, error=error) + if self.network_trace: + _LOGGER.debug("-> %r", detach_frame, extra=self.network_trace_params) + await 
self._session._outgoing_detach(detach_frame) # pylint: disable=protected-access + if close: + self._is_closed = True + + async def _incoming_detach(self, frame): + if self.network_trace: + _LOGGER.debug("<- %r", DetachFrame(*frame), extra=self.network_trace_params) + if self.state == LinkState.ATTACHED: + await self._outgoing_detach(close=frame[1]) # closed + elif frame[1] and not self._is_closed and self.state in [LinkState.ATTACH_SENT, LinkState.ATTACH_RCVD]: + # Received a closing detach after we sent a non-closing detach. + # In this case, we MUST signal that we closed by reattaching and then sending a closing detach. + await self._outgoing_attach() + await self._outgoing_detach(close=True) + # TODO: on_detach_hook + if frame[2]: # error + # frame[2][0] is condition, frame[2][1] is description, frame[2][2] is info + error_cls = AMQPLinkRedirect if frame[2][0] == ErrorCondition.LinkRedirect else AMQPLinkError + self._error = error_cls(condition=frame[2][0], description=frame[2][1], info=frame[2][2]) + await self._set_state(LinkState.ERROR) + else: + await self._set_state(LinkState.DETACHED) + + async def attach(self): + if self._is_closed: + raise ValueError("Link already closed.") + await self._outgoing_attach() + await self._set_state(LinkState.ATTACH_SENT) + + async def detach(self, close=False, error=None): + if self.state in (LinkState.DETACHED, LinkState.ERROR): + return + try: + self._check_if_closed() + if self.state in [LinkState.ATTACH_SENT, LinkState.ATTACH_RCVD]: + await self._outgoing_detach(close=close, error=error) + await self._set_state(LinkState.DETACHED) + elif self.state == LinkState.ATTACHED: + await self._outgoing_detach(close=close, error=error) + await self._set_state(LinkState.DETACH_SENT) + except Exception as exc: # pylint: disable=broad-except + _LOGGER.info("An error occurred when detaching the link: %r", exc, extra=self.network_trace_params) + await self._set_state(LinkState.DETACHED) + + async def flow(self, *, link_credit: 
Optional[int] = None, **kwargs) -> None:
+        self.current_link_credit = link_credit if link_credit is not None else self.link_credit
+        await self._outgoing_flow(**kwargs)
diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/aio/_management_link_async.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/aio/_management_link_async.py
new file mode 100644
index 0000000000000..94f3163accfd1
--- /dev/null
+++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/aio/_management_link_async.py
@@ -0,0 +1,249 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+
+import time
+import logging
+from functools import partial
+
+from ..management_link import PendingManagementOperation
+from ._sender_async import SenderLink
+from ._receiver_async import ReceiverLink
+from ..constants import (
+    ManagementLinkState,
+    LinkState,
+    SenderSettleMode,
+    ReceiverSettleMode,
+    ManagementExecuteOperationResult,
+    ManagementOpenResult,
+    SEND_DISPOSITION_REJECT,
+    MessageDeliveryState,
+    LinkDeliverySettleReason
+)
+from ..error import AMQPException, ErrorCondition
+from ..message import Properties, _MessageDelivery

+_LOGGER = logging.getLogger(__name__)
+
+
+class ManagementLink(object):  # pylint:disable=too-many-instance-attributes
+    """An AMQP management link pair over a single session.
+
+    Pairs a sender and a receiver link on the same endpoint in order to
+    issue management requests and correlate their responses.
+    """
+
+    def __init__(self, session, endpoint, **kwargs):
+        self.next_message_id = 0
+        self.state = ManagementLinkState.IDLE
+        self._pending_operations = []
+        self._session = session
+        self._network_trace_params = kwargs.get('network_trace_params')
+        self._request_link: SenderLink = session.create_sender_link(
+            endpoint,
+            source_address=endpoint,
+            on_link_state_change=self._on_sender_state_change,
send_settle_mode=SenderSettleMode.Unsettled, + rcv_settle_mode=ReceiverSettleMode.First, + network_trace=kwargs.get("network_trace", False) + ) + self._response_link: ReceiverLink = session.create_receiver_link( + endpoint, + target_address=endpoint, + on_link_state_change=self._on_receiver_state_change, + on_transfer=self._on_message_received, + send_settle_mode=SenderSettleMode.Unsettled, + rcv_settle_mode=ReceiverSettleMode.First, + network_trace=kwargs.get("network_trace", False) + ) + self._on_amqp_management_error = kwargs.get("on_amqp_management_error") + self._on_amqp_management_open_complete = kwargs.get("on_amqp_management_open_complete") + + self._status_code_field = kwargs.get("status_code_field", b"statusCode") + self._status_description_field = kwargs.get("status_description_field", b"statusDescription") + + self._sender_connected = False + self._receiver_connected = False + + async def __aenter__(self): + await self.open() + return self + + async def __aexit__(self, *args): + await self.close() + + async def _on_sender_state_change(self, previous_state, new_state): + _LOGGER.info( + "Management link sender state changed: %r -> %r", + previous_state, + new_state, + extra=self._network_trace_params + ) + if new_state == previous_state: + return + if self.state == ManagementLinkState.OPENING: + if new_state == LinkState.ATTACHED: + self._sender_connected = True + if self._receiver_connected: + self.state = ManagementLinkState.OPEN + await self._on_amqp_management_open_complete(ManagementOpenResult.OK) + elif new_state in [LinkState.DETACHED, LinkState.DETACH_SENT, LinkState.DETACH_RCVD, LinkState.ERROR]: + self.state = ManagementLinkState.IDLE + await self._on_amqp_management_open_complete(ManagementOpenResult.ERROR) + elif self.state == ManagementLinkState.OPEN: + if new_state is not LinkState.ATTACHED: + self.state = ManagementLinkState.ERROR + await self._on_amqp_management_error() + elif self.state == ManagementLinkState.CLOSING: + if new_state not 
in [LinkState.DETACHED, LinkState.DETACH_SENT, LinkState.DETACH_RCVD]: + self.state = ManagementLinkState.ERROR + await self._on_amqp_management_error() + elif self.state == ManagementLinkState.ERROR: + # All state transitions shall be ignored. + return + + async def _on_receiver_state_change(self, previous_state, new_state): + _LOGGER.info( + "Management link receiver state changed: %r -> %r", + previous_state, + new_state, + extra=self._network_trace_params + ) + if new_state == previous_state: + return + if self.state == ManagementLinkState.OPENING: + if new_state == LinkState.ATTACHED: + self._receiver_connected = True + if self._sender_connected: + self.state = ManagementLinkState.OPEN + await self._on_amqp_management_open_complete(ManagementOpenResult.OK) + elif new_state in [LinkState.DETACHED, LinkState.DETACH_SENT, LinkState.DETACH_RCVD, LinkState.ERROR]: + self.state = ManagementLinkState.IDLE + await self._on_amqp_management_open_complete(ManagementOpenResult.ERROR) + elif self.state == ManagementLinkState.OPEN: + if new_state is not LinkState.ATTACHED: + self.state = ManagementLinkState.ERROR + await self._on_amqp_management_error() + elif self.state == ManagementLinkState.CLOSING: + if new_state not in [LinkState.DETACHED, LinkState.DETACH_SENT, LinkState.DETACH_RCVD]: + self.state = ManagementLinkState.ERROR + await self._on_amqp_management_error() + elif self.state == ManagementLinkState.ERROR: + # All state transitions shall be ignored. 
+ return + + async def _on_message_received(self, _, message): + message_properties = message.properties + correlation_id = message_properties[5] + response_detail = message.application_properties + + status_code = response_detail.get(self._status_code_field) + status_description = response_detail.get(self._status_description_field) + + to_remove_operation = None + for operation in self._pending_operations: + if operation.message.properties.message_id == correlation_id: + to_remove_operation = operation + break + if to_remove_operation: + mgmt_result = ( + ManagementExecuteOperationResult.OK + if 200 <= status_code <= 299 + else ManagementExecuteOperationResult.FAILED_BAD_STATUS + ) + await to_remove_operation.on_execute_operation_complete( + mgmt_result, status_code, status_description, message, response_detail.get(b"error-condition") + ) + self._pending_operations.remove(to_remove_operation) + + async def _on_send_complete(self, message_delivery, reason, state): + if reason == LinkDeliverySettleReason.DISPOSITION_RECEIVED and SEND_DISPOSITION_REJECT in state: + # sample reject state: {'rejected': [[b'amqp:not-allowed', b"Invalid command 'RE1AD'.", None]]} + to_remove_operation = None + for operation in self._pending_operations: + if message_delivery.message == operation.message: + to_remove_operation = operation + break + self._pending_operations.remove(to_remove_operation) + # TODO: better error handling + # AMQPException is too general? to be more specific: MessageReject(Error) or AMQPManagementError? 
+            # or should there be an error mapping which maps the condition to the error type
+
+            # The callback is defined in management_operation.py
+            await to_remove_operation.on_execute_operation_complete(
+                ManagementExecuteOperationResult.ERROR,
+                None,
+                None,
+                message_delivery.message,
+                error=AMQPException(
+                    condition=state[SEND_DISPOSITION_REJECT][0][0],  # 0 is error condition
+                    description=state[SEND_DISPOSITION_REJECT][0][1],  # 1 is error description
+                    info=state[SEND_DISPOSITION_REJECT][0][2],  # 2 is error info
+                ),
+            )
+
+    async def open(self):
+        if self.state != ManagementLinkState.IDLE:
+            raise ValueError("Management links are already open or opening.")
+        self.state = ManagementLinkState.OPENING
+        await self._response_link.attach()
+        await self._request_link.attach()
+
+    async def execute_operation(self, message, on_execute_operation_complete, **kwargs):
+        """Execute a request and wait on a response.
+
+        :param message: The message to send in the management request.
+        :type message: ~uamqp.message.Message
+        :param on_execute_operation_complete: Callback to be called when the operation is complete.
+         The following values will be passed to the callback: operation_id, operation_result, status_code,
+         status_description, raw_message and error.
+        :type on_execute_operation_complete: Callable[[str, str, int, str, ~uamqp.message.Message, Exception], None]
+        :keyword operation: The type of operation to be performed. This value will
+         be service-specific, but common values include READ, CREATE and UPDATE.
+         This value will be added as an application property on the message.
+        :paramtype operation: bytes or str
+        :keyword type: The type on which to carry out the operation. This will
+         be specific to the entities of the service. This value will be added as
+         an application property on the message.
+        :paramtype type: bytes or str
+        :keyword str locales: A list of locales that the sending peer permits for incoming
+         informational text in response messages.
+ :keyword float timeout: Provide an optional timeout in seconds within which a response + to the management request must be received. + :rtype: None + """ + timeout = kwargs.get("timeout") + message.application_properties["operation"] = kwargs.get("operation") + message.application_properties["type"] = kwargs.get("type") + if "locales" in kwargs: + message.application_properties["locales"] = kwargs.get("locales") + try: + # TODO: namedtuple is immutable, which may push us to re-think about the namedtuple approach for Message + new_properties = message.properties._replace(message_id=self.next_message_id) + except AttributeError: + new_properties = Properties(message_id=self.next_message_id) + message = message._replace(properties=new_properties) + expire_time = (time.time() + timeout) if timeout else None + message_delivery = _MessageDelivery(message, MessageDeliveryState.WaitingToBeSent, expire_time) + + on_send_complete = partial(self._on_send_complete, message_delivery) + + await self._request_link.send_transfer(message, on_send_complete=on_send_complete, timeout=timeout) + self.next_message_id += 1 + self._pending_operations.append(PendingManagementOperation(message, on_execute_operation_complete)) + + async def close(self): + if self.state != ManagementLinkState.IDLE: + self.state = ManagementLinkState.CLOSING + await self._response_link.detach(close=True) + await self._request_link.detach(close=True) + for pending_operation in self._pending_operations: + await pending_operation.on_execute_operation_complete( + ManagementExecuteOperationResult.LINK_CLOSED, + None, + None, + pending_operation.message, + AMQPException(condition=ErrorCondition.ClientError, description="Management link already closed."), + ) + self._pending_operations = [] + self.state = ManagementLinkState.IDLE diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/aio/_management_operation_async.py 
b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/aio/_management_operation_async.py new file mode 100644 index 0000000000000..e5830d7d0ff8e --- /dev/null +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/aio/_management_operation_async.py @@ -0,0 +1,140 @@ +#------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +#-------------------------------------------------------------------------- +import logging +import uuid +import time +from functools import partial + +from ._management_link_async import ManagementLink +from ..error import ( + AMQPLinkError, + ErrorCondition +) + +from ..constants import ( + ManagementOpenResult, + ManagementExecuteOperationResult +) + +_LOGGER = logging.getLogger(__name__) + + +class ManagementOperation(object): + def __init__(self, session, endpoint='$management', **kwargs): + self._mgmt_link_open_status = None + + self._session = session + self._connection = self._session._connection + self._network_trace_params = { + "amqpConnection": self._session._connection._container_id, + "amqpSession": self._session.name, + "amqpLink": None + } + self._mgmt_link = self._session.create_request_response_link_pair( + endpoint=endpoint, + on_amqp_management_open_complete=self._on_amqp_management_open_complete, + on_amqp_management_error=self._on_amqp_management_error, + **kwargs + ) # type: ManagementLink + self._responses = {} + self._mgmt_error = None + + async def _on_amqp_management_open_complete(self, result): + """Callback run when the send/receive links are open and ready + to process messages. + + :param result: Whether the link opening was successful. 
+ :type result: int + """ + self._mgmt_link_open_status = result + + async def _on_amqp_management_error(self): + """Callback run if an error occurs in the send/receive links.""" + # TODO: This probably shouldn't be ValueError + self._mgmt_error = ValueError("Management Operation error occurred.") + + async def _on_execute_operation_complete( + self, + operation_id, + operation_result, + status_code, + status_description, + raw_message, + error=None + ): + _LOGGER.debug( + "Management operation completed, id: %r; result: %r; code: %r; description: %r; error: %r", + operation_id, + operation_result, + status_code, + status_description, + error, + extra=self._network_trace_params + ) + + if operation_result in\ + (ManagementExecuteOperationResult.ERROR, ManagementExecuteOperationResult.LINK_CLOSED): + self._mgmt_error = error + _LOGGER.error( + "Failed to complete management operation due to error: %r.", + error, + extra=self._network_trace_params + ) + else: + self._responses[operation_id] = (status_code, status_description, raw_message) + + async def execute(self, message, operation=None, operation_type=None, timeout=0): + start_time = time.time() + operation_id = str(uuid.uuid4()) + self._responses[operation_id] = None + self._mgmt_error = None + + await self._mgmt_link.execute_operation( + message, + partial(self._on_execute_operation_complete, operation_id), + timeout=timeout, + operation=operation, + type=operation_type + ) + + while not self._responses[operation_id] and not self._mgmt_error: + if timeout and timeout > 0: + now = time.time() + if (now - start_time) >= timeout: + raise TimeoutError("Failed to receive mgmt response within {} seconds".format(timeout)) + await self._connection.listen() + + if self._mgmt_error: + self._responses.pop(operation_id) + raise self._mgmt_error # pylint: disable=raising-bad-type + + response = self._responses.pop(operation_id) + return response + + async def open(self): + self._mgmt_link_open_status = ManagementOpenResult.OPENING +
await self._mgmt_link.open() + + async def ready(self): + try: + raise self._mgmt_error # pylint: disable=raising-bad-type + except TypeError: + pass + + if self._mgmt_link_open_status == ManagementOpenResult.OPENING: + return False + if self._mgmt_link_open_status == ManagementOpenResult.OK: + return True + # ManagementOpenResult.ERROR or CANCELLED + # TODO: update below with correct status code + info + raise AMQPLinkError( + condition=ErrorCondition.ClientError, + description="Failed to open mgmt link, management link status: {}".format(self._mgmt_link_open_status), + info=None + ) + + async def close(self): + await self._mgmt_link.close() diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/aio/_receiver_async.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/aio/_receiver_async.py new file mode 100644 index 0000000000000..7d3c6c5401606 --- /dev/null +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/aio/_receiver_async.py @@ -0,0 +1,124 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. 
+# -------------------------------------------------------------------------- + +import uuid +import logging +from typing import Optional, Union + +from .._decode import decode_payload +from ._link_async import Link +from ..constants import LinkState, Role +from ..performatives import ( + TransferFrame, + DispositionFrame, +) +from ..outcomes import Received, Accepted, Rejected, Released, Modified + + +_LOGGER = logging.getLogger(__name__) + + +class ReceiverLink(Link): + def __init__(self, session, handle, source_address, **kwargs): + name = kwargs.pop("name", None) or str(uuid.uuid4()) + role = Role.Receiver + if "target_address" not in kwargs: + kwargs["target_address"] = "receiver-link-{}".format(name) + super(ReceiverLink, self).__init__(session, handle, name, role, source_address=source_address, **kwargs) + self._on_transfer = kwargs.pop("on_transfer") + self._received_payload = bytearray() + + @classmethod + def from_incoming_frame(cls, session, handle, frame): + # TODO: Assuming we establish all links for now... + # check link_create_from_endpoint in C lib + raise NotImplementedError("Pending") + + async def _process_incoming_message(self, frame, message): + try: + return await self._on_transfer(frame, message) + except Exception as e: # pylint: disable=broad-except + _LOGGER.error("Transfer callback function failed with error: %r", e, extra=self.network_trace_params) + return None + + async def _incoming_attach(self, frame): + await super(ReceiverLink, self)._incoming_attach(frame) + if frame[9] is None: # initial_delivery_count + _LOGGER.info("Cannot get initial-delivery-count. Detaching link", extra=self.network_trace_params) + await self._set_state(LinkState.DETACHED) # TODO: Send detach now? 
+ return + self.delivery_count = frame[9] + self.current_link_credit = self.link_credit + await self._outgoing_flow() + + async def _incoming_transfer(self, frame): + if self.network_trace: + _LOGGER.debug("<- %r", TransferFrame(payload=b"***", *frame[:-1]), extra=self.network_trace_params) + self.current_link_credit -= 1 + self.delivery_count += 1 + self.received_delivery_id = frame[1] # delivery_id + if not self.received_delivery_id and not self._received_payload: + pass # TODO: delivery error + if self._received_payload or frame[5]: # more + self._received_payload.extend(frame[11]) + if not frame[5]: + if self._received_payload: + message = decode_payload(memoryview(self._received_payload)) + self._received_payload = bytearray() + else: + message = decode_payload(frame[11]) + delivery_state = await self._process_incoming_message(frame, message) + if not frame[4] and delivery_state: # settled + await self._outgoing_disposition( + first=frame[1], + last=frame[1], + settled=True, + state=delivery_state, + batchable=None + ) + + async def _wait_for_response(self, wait: Union[bool, float]) -> None: + if wait is True: + await self._session._connection.listen(wait=False) # pylint: disable=protected-access + if self.state == LinkState.ERROR: + raise self._error + elif wait: + await self._session._connection.listen(wait=wait) # pylint: disable=protected-access + if self.state == LinkState.ERROR: + raise self._error + + async def _outgoing_disposition( + self, + first: int, + last: Optional[int], + settled: Optional[bool], + state: Optional[Union[Received, Accepted, Rejected, Released, Modified]], + batchable: Optional[bool], + ): + disposition_frame = DispositionFrame( + role=self.role, first=first, last=last, settled=settled, state=state, batchable=batchable + ) + if self.network_trace: + _LOGGER.debug("-> %r", DispositionFrame(*disposition_frame), extra=self.network_trace_params) + await self._session._outgoing_disposition(disposition_frame) # pylint: disable=protected-access +
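The multi-frame Transfer handling in `_incoming_transfer` above reduces to a small accumulator: payload bytes are buffered while the frame's `more` flag is set, and the message is decoded only once a frame arrives with `more` unset. Below is a minimal standalone sketch of that logic; `TransferReassembler` and `feed` are illustrative names, not part of this codebase.

```python
# Simplified model of ReceiverLink._incoming_transfer's payload handling:
# buffer fragments while "more" is set; emit the full payload on the
# final frame. Single-frame messages bypass the buffer entirely.
class TransferReassembler:
    def __init__(self):
        self._buffer = bytearray()

    def feed(self, payload, more):
        """Return the complete payload bytes if this frame finishes a message, else None."""
        if self._buffer or more:
            # Continuation (or start) of a fragmented message: accumulate.
            self._buffer.extend(payload)
        if more:
            return None
        if self._buffer:
            complete = bytes(self._buffer)
            self._buffer = bytearray()
            return complete
        # Single-frame message: the payload is already complete.
        return bytes(payload)
```

In the real code, the returned bytes would then go through `decode_payload`, mirroring the two decode branches above.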
+ async def attach(self): + await super().attach() + self._received_payload = bytearray() + + async def send_disposition( + self, + *, + wait: Union[bool, float] = False, + first_delivery_id: int, + last_delivery_id: Optional[int] = None, + settled: Optional[bool] = None, + delivery_state: Optional[Union[Received, Accepted, Rejected, Released, Modified]] = None, + batchable: Optional[bool] = None + ): + if self._is_closed: + raise ValueError("Link already closed.") + await self._outgoing_disposition(first_delivery_id, last_delivery_id, settled, delivery_state, batchable) + await self._wait_for_response(wait) diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/aio/_sasl_async.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/aio/_sasl_async.py new file mode 100644 index 0000000000000..441eb40ec8744 --- /dev/null +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/aio/_sasl_async.py @@ -0,0 +1,149 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +from ._transport_async import AsyncTransport, WebSocketTransportAsync +from ..constants import SASLCode, SASL_HEADER_FRAME, WEBSOCKET_PORT +from .._transport import AMQPS_PORT +from ..performatives import SASLInit + + +_SASL_FRAME_TYPE = b"\x01" + + +# TODO: do we need it here? it's a duplicate of the sync version +class SASLPlainCredential(object): + """PLAIN SASL authentication mechanism. 
+ See https://tools.ietf.org/html/rfc4616 for details + """ + + mechanism = b"PLAIN" + + def __init__(self, authcid, passwd, authzid=None): + self.authcid = authcid + self.passwd = passwd + self.authzid = authzid + + def start(self): + if self.authzid: + login_response = self.authzid.encode("utf-8") + else: + login_response = b"" + login_response += b"\0" + login_response += self.authcid.encode("utf-8") + login_response += b"\0" + login_response += self.passwd.encode("utf-8") + return login_response + + +# TODO: do we need it here? it's a duplicate of the sync version +class SASLAnonymousCredential(object): + """ANONYMOUS SASL authentication mechanism. + See https://tools.ietf.org/html/rfc4505 for details + """ + + mechanism = b"ANONYMOUS" + + def start(self): # pylint: disable=no-self-use + return b"" + + +# TODO: do we need it here? it's a duplicate of the sync version +class SASLExternalCredential(object): + """EXTERNAL SASL mechanism. + Enables external authentication, i.e. not handled through this protocol. + Only passes 'EXTERNAL' as authentication mechanism, but no further + authentication data. + """ + + mechanism = b"EXTERNAL" + + def start(self): # pylint: disable=no-self-use + return b"" + + +class SASLTransportMixinAsync: # pylint: disable=no-member + async def _negotiate(self): + await self.write(SASL_HEADER_FRAME) + _, returned_header = await self.receive_frame() + if returned_header[1] != SASL_HEADER_FRAME: + raise ValueError( + f"""Mismatching AMQP header protocol. 
Expected: {SASL_HEADER_FRAME!r}, """ + f"""received: {returned_header[1]!r}""" + ) + + _, supported_mechanisms = await self.receive_frame(verify_frame_type=1) + if ( + self.credential.mechanism not in supported_mechanisms[1][0] + ): # sasl_server_mechanisms + raise ValueError( + "Unsupported SASL credential type: {}".format(self.credential.mechanism) + ) + sasl_init = SASLInit( + mechanism=self.credential.mechanism, + initial_response=self.credential.start(), + hostname=self.host, + ) + await self.send_frame(0, sasl_init, frame_type=_SASL_FRAME_TYPE) + + _, next_frame = await self.receive_frame(verify_frame_type=1) + frame_type, fields = next_frame + if frame_type != 0x00000044: # SASLOutcome + raise NotImplementedError("Unsupported SASL challenge") + if fields[0] == SASLCode.Ok: # code + return + raise ValueError( + "SASL negotiation failed.\nOutcome: {}\nDetails: {}".format(*fields) + ) + + + class SASLTransport(AsyncTransport, SASLTransportMixinAsync): + def __init__( + self, + host, + credential, + *, + port=AMQPS_PORT, + connect_timeout=None, + ssl_opts=None, + **kwargs, + ): + self.credential = credential + ssl_opts = ssl_opts or True + super(SASLTransport, self).__init__( + host, + port=port, + connect_timeout=connect_timeout, + ssl_opts=ssl_opts, + **kwargs, + ) + + async def negotiate(self): + await self._negotiate() + + + class SASLWithWebSocket(WebSocketTransportAsync, SASLTransportMixinAsync): + def __init__( + self, + host, + credential, + *, + port=WEBSOCKET_PORT, + connect_timeout=None, + ssl_opts=None, + **kwargs, + ): + self.credential = credential + ssl_opts = ssl_opts or True + super().__init__( + host, + port=port, + connect_timeout=connect_timeout, + ssl_opts=ssl_opts, + **kwargs, + ) + + async def negotiate(self): + await self._negotiate() diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/aio/_sender_async.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/aio/_sender_async.py new file mode 100644 index
0000000000000..29a4c052baa31 --- /dev/null +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/aio/_sender_async.py @@ -0,0 +1,203 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +import struct +import uuid +import logging +import time +import asyncio + +from .._encode import encode_payload +from ._link_async import Link +from ..constants import SessionTransferState, LinkDeliverySettleReason, LinkState, Role, SenderSettleMode, SessionState +from ..error import AMQPLinkError, ErrorCondition, MessageException + +_LOGGER = logging.getLogger(__name__) + + +class PendingDelivery(object): + def __init__(self, **kwargs): + self.message = kwargs.get("message") + self.sent = False + self.frame = None + self.on_delivery_settled = kwargs.get("on_delivery_settled") + self.start = time.time() + self.transfer_state = None + self.timeout = kwargs.get("timeout") + self.settled = kwargs.get("settled", False) + self._network_trace_params = kwargs.get('network_trace_params') + + async def on_settled(self, reason, state): + if self.on_delivery_settled and not self.settled: + try: + await self.on_delivery_settled(reason, state) + except Exception as e: # pylint:disable=broad-except + _LOGGER.warning( + "Message 'on_send_complete' callback failed: %r", + e, + extra=self._network_trace_params + ) + self.settled = True + + +class SenderLink(Link): + def __init__(self, session, handle, target_address, **kwargs): + name = kwargs.pop("name", None) or str(uuid.uuid4()) + role = Role.Sender + if "source_address" not in kwargs: + kwargs["source_address"] = "sender-link-{}".format(name) + super(SenderLink, self).__init__(session, handle, name, role, target_address=target_address, **kwargs) + self._pending_deliveries = 
[] + + @classmethod + def from_incoming_frame(cls, session, handle, frame): + # TODO: Assuming we establish all links for now... + # check link_create_from_endpoint in C lib + raise NotImplementedError("Pending") + + # In theory we should not need to purge pending deliveries on attach/detach, as a link should + # be resumable; however, this is not yet supported. + async def _incoming_attach(self, frame): + try: + await super(SenderLink, self)._incoming_attach(frame) + except AMQPLinkError: + await self._remove_pending_deliveries() + raise + self.current_link_credit = self.link_credit + await self._outgoing_flow() + await self.update_pending_deliveries() + + async def _incoming_detach(self, frame): + await super(SenderLink, self)._incoming_detach(frame) + await self._remove_pending_deliveries() + + async def _incoming_flow(self, frame): + rcv_link_credit = frame[6] # link_credit + rcv_delivery_count = frame[5] # delivery_count + if frame[4] is not None: # handle + if rcv_link_credit is None or rcv_delivery_count is None: + _LOGGER.info( + "Unable to get link-credit or delivery-count from incoming FLOW. Detaching link.", + extra=self.network_trace_params + ) + await self._remove_pending_deliveries() + await self._set_state(LinkState.DETACHED) # TODO: Send detach now?
+ else: + self.current_link_credit = rcv_delivery_count + rcv_link_credit - self.delivery_count + await self.update_pending_deliveries() + + async def _outgoing_transfer(self, delivery): + output = bytearray() + encode_payload(output, delivery.message) + delivery_count = self.delivery_count + 1 + delivery.frame = { + "handle": self.handle, + "delivery_tag": struct.pack(">I", abs(delivery_count)), + "message_format": delivery.message._code, # pylint:disable=protected-access + "settled": delivery.settled, + "more": False, + "rcv_settle_mode": None, + "state": None, + "resume": None, + "aborted": None, + "batchable": None, + "payload": output, + } + await self._session._outgoing_transfer( # pylint:disable=protected-access + delivery, + self.network_trace_params if self.network_trace else None + ) + sent_and_settled = False + if delivery.transfer_state == SessionTransferState.OKAY: + self.delivery_count = delivery_count + self.current_link_credit -= 1 + delivery.sent = True + if delivery.settled: + await delivery.on_settled(LinkDeliverySettleReason.SETTLED, None) + sent_and_settled = True + # elif delivery.transfer_state == SessionTransferState.ERROR: + # TODO: Session wasn't mapped yet - re-adding to the outgoing delivery queue? 
+ return sent_and_settled + + async def _incoming_disposition(self, frame): + if not frame[3]: # settled + return + range_end = (frame[2] or frame[1]) + 1 # first or last + settled_ids = list(range(frame[1], range_end)) + unsettled = [] + for delivery in self._pending_deliveries: + if delivery.sent and delivery.frame["delivery_id"] in settled_ids: + await delivery.on_settled(LinkDeliverySettleReason.DISPOSITION_RECEIVED, frame[4]) # state + continue + unsettled.append(delivery) + self._pending_deliveries = unsettled + + async def _remove_pending_deliveries(self): + futures = [] + for delivery in self._pending_deliveries: + futures.append(asyncio.ensure_future(delivery.on_settled(LinkDeliverySettleReason.NOT_DELIVERED, None))) + await asyncio.gather(*futures) + self._pending_deliveries = [] + + async def _on_session_state_change(self): + if self._session.state == SessionState.DISCARDING: + await self._remove_pending_deliveries() + await super()._on_session_state_change() + + async def update_pending_deliveries(self): + if self.current_link_credit <= 0: + self.current_link_credit = self.link_credit + await self._outgoing_flow() + now = time.time() + pending = [] + for delivery in self._pending_deliveries: + if delivery.timeout and (now - delivery.start) >= delivery.timeout: + await delivery.on_settled(LinkDeliverySettleReason.TIMEOUT, None) + continue + if not delivery.sent: + sent_and_settled = await self._outgoing_transfer(delivery) + if sent_and_settled: + continue + pending.append(delivery) + self._pending_deliveries = pending + + async def send_transfer(self, message, *, send_async=False, **kwargs): + self._check_if_closed() + if self.state != LinkState.ATTACHED: + raise AMQPLinkError( + condition=ErrorCondition.ClientError, + description="Link is not attached."
+ ) + settled = self.send_settle_mode == SenderSettleMode.Settled + if self.send_settle_mode == SenderSettleMode.Mixed: + settled = kwargs.pop("settled", True) + delivery = PendingDelivery( + on_delivery_settled=kwargs.get("on_send_complete"), + timeout=kwargs.get("timeout"), + message=message, + settled=settled, + network_trace_params=self.network_trace_params + ) + if self.current_link_credit == 0 or send_async: + self._pending_deliveries.append(delivery) + else: + sent_and_settled = await self._outgoing_transfer(delivery) + if not sent_and_settled: + self._pending_deliveries.append(delivery) + return delivery + + async def cancel_transfer(self, delivery): + try: + index = self._pending_deliveries.index(delivery) + except ValueError: + raise ValueError("Found no matching pending transfer.") + delivery = self._pending_deliveries[index] + if delivery.sent: + raise MessageException( + ErrorCondition.ClientError, + message="Transfer cannot be cancelled. Message has already been sent and awaiting disposition.", + ) + await delivery.on_settled(LinkDeliverySettleReason.CANCELLED, None) + self._pending_deliveries.pop(index) diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/aio/_session_async.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/aio/_session_async.py new file mode 100644 index 0000000000000..a13a228fdd1a7 --- /dev/null +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/aio/_session_async.py @@ -0,0 +1,460 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. 
+# -------------------------------------------------------------------------- + +from __future__ import annotations +import uuid +import logging +import time +import asyncio +from typing import Optional, Union, List + +from ..constants import ConnectionState, SessionState, SessionTransferState, Role +from ._sender_async import SenderLink +from ._receiver_async import ReceiverLink +from ._management_link_async import ManagementLink +from ..performatives import ( + BeginFrame, + EndFrame, + FlowFrame, + TransferFrame, + DispositionFrame, +) +from .._encode import encode_frame +from ..error import AMQPError, ErrorCondition + +_LOGGER = logging.getLogger(__name__) + + +class Session(object): # pylint: disable=too-many-instance-attributes + """ + :param int remote_channel: The remote channel for this Session. + :param int next_outgoing_id: The transfer-id of the first transfer id the sender will send. + :param int incoming_window: The initial incoming-window of the sender. + :param int outgoing_window: The initial outgoing-window of the sender. + :param int handle_max: The maximum handle value that may be used on the Session. + :param list(str) offered_capabilities: The extension capabilities the sender supports. + :param list(str) desired_capabilities: The extension capabilities the sender may use if the receiver supports + :param dict properties: Session properties. 
+ """ + + def __init__(self, connection, channel, **kwargs): + self.name = kwargs.pop("name", None) or str(uuid.uuid4()) + self.state = SessionState.UNMAPPED + self.handle_max = kwargs.get("handle_max", 4294967295) + self.properties = kwargs.pop("properties", None) + self.remote_properties = None + self.channel = channel + self.remote_channel = None + self.next_outgoing_id = kwargs.pop("next_outgoing_id", 0) + self.next_incoming_id = None + self.incoming_window = kwargs.pop("incoming_window", 1) + self.outgoing_window = kwargs.pop("outgoing_window", 1) + self.target_incoming_window = self.incoming_window + self.remote_incoming_window = 0 + self.remote_outgoing_window = 0 + self.offered_capabilities = None + self.desired_capabilities = kwargs.pop("desired_capabilities", None) + + self.allow_pipelined_open = kwargs.pop("allow_pipelined_open", True) + self.idle_wait_time = kwargs.get("idle_wait_time", 0.1) + self.network_trace = kwargs["network_trace"] + self.network_trace_params = kwargs["network_trace_params"] + self.network_trace_params["amqpSession"] = self.name + + self.links = {} + self._connection = connection + self._output_handles = {} + self._input_handles = {} + + async def __aenter__(self): + await self.begin() + return self + + async def __aexit__(self, *args): + await self.end() + + @classmethod + def from_incoming_frame(cls, connection, channel): + # check session_create_from_endpoint in C lib + new_session = cls(connection, channel) + return new_session + + async def _set_state(self, new_state): + # type: (SessionState) -> None + """Update the session state.""" + if new_state is None: + return + previous_state = self.state + self.state = new_state + _LOGGER.info( + "Session state changed: %r -> %r", + previous_state, + new_state, + extra=self.network_trace_params, + ) + for link in self.links.values(): + await link._on_session_state_change() # pylint: disable=protected-access + + async def _on_connection_state_change(self): + if self._connection.state 
in [ConnectionState.CLOSE_RCVD, ConnectionState.END]: + if self.state not in [SessionState.DISCARDING, SessionState.UNMAPPED]: + await self._set_state(SessionState.DISCARDING) + + def _get_next_output_handle(self): + # type: () -> int + """Get the next available outgoing handle number within the max handle limit. + + :raises ValueError: If maximum handle has been reached. + :returns: The next available outgoing handle number. + :rtype: int + """ + if len(self._output_handles) >= self.handle_max: + raise ValueError("Maximum number of handles ({}) has been reached.".format(self.handle_max)) + next_handle = next(i for i in range(1, self.handle_max) if i not in self._output_handles) + return next_handle + + async def _outgoing_begin(self): + begin_frame = BeginFrame( + remote_channel=self.remote_channel if self.state == SessionState.BEGIN_RCVD else None, + next_outgoing_id=self.next_outgoing_id, + outgoing_window=self.outgoing_window, + incoming_window=self.incoming_window, + handle_max=self.handle_max, + offered_capabilities=self.offered_capabilities if self.state == SessionState.BEGIN_RCVD else None, + desired_capabilities=self.desired_capabilities if self.state == SessionState.UNMAPPED else None, + properties=self.properties, + ) + if self.network_trace: + _LOGGER.debug("-> %r", begin_frame, extra=self.network_trace_params) + await self._connection._process_outgoing_frame(self.channel, begin_frame) # pylint: disable=protected-access + + async def _incoming_begin(self, frame): + if self.network_trace: + _LOGGER.debug("<- %r", BeginFrame(*frame), extra=self.network_trace_params) + self.handle_max = frame[4] # handle_max + self.next_incoming_id = frame[1] # next_outgoing_id + self.remote_incoming_window = frame[2] # incoming_window + self.remote_outgoing_window = frame[3] # outgoing_window + self.remote_properties = frame[7] + if self.state == SessionState.BEGIN_SENT: + self.remote_channel = frame[0] # remote_channel + await self._set_state(SessionState.MAPPED) + elif 
self.state == SessionState.UNMAPPED: + await self._set_state(SessionState.BEGIN_RCVD) + await self._outgoing_begin() + await self._set_state(SessionState.MAPPED) + + async def _outgoing_end(self, error=None): + end_frame = EndFrame(error=error) + if self.network_trace: + _LOGGER.debug("-> %r", end_frame, extra=self.network_trace_params) + await self._connection._process_outgoing_frame(self.channel, end_frame) # pylint: disable=protected-access + + async def _incoming_end(self, frame): + if self.network_trace: + _LOGGER.debug("<- %r", EndFrame(*frame), extra=self.network_trace_params) + if self.state not in [ + SessionState.END_RCVD, + SessionState.END_SENT, + SessionState.DISCARDING, + ]: + await self._set_state(SessionState.END_RCVD) + for _, link in self.links.items(): + await link.detach() + # TODO: handling error + await self._outgoing_end() + await self._set_state(SessionState.UNMAPPED) + + async def _outgoing_attach(self, frame): + await self._connection._process_outgoing_frame(self.channel, frame) # pylint: disable=protected-access + + async def _incoming_attach(self, frame): + try: + self._input_handles[frame[1]] = self.links[frame[0].decode("utf-8")] # name and handle + await self._input_handles[frame[1]]._incoming_attach(frame) # pylint: disable=protected-access + except KeyError: + try: + outgoing_handle = self._get_next_output_handle() + except ValueError: + _LOGGER.error( + "Unable to attach new link - cannot allocate more handles.", + extra=self.network_trace_params + ) + # detach the link that would have been set. + await self.links[frame[0].decode("utf-8")].detach( + error=AMQPError( + condition=ErrorCondition.LinkDetachForced, + description=f"Cannot allocate more handles, the max number of handles is {self.handle_max}. 
Detaching link", # pylint: disable=line-too-long + info=None, + ) + ) + return + if frame[2] == Role.Sender: + new_link = ReceiverLink.from_incoming_frame(self, outgoing_handle, frame) + else: + new_link = SenderLink.from_incoming_frame(self, outgoing_handle, frame) + await new_link._incoming_attach(frame) # pylint: disable=protected-access + self.links[frame[0]] = new_link + self._output_handles[outgoing_handle] = new_link + self._input_handles[frame[1]] = new_link + except ValueError as e: + # Reject Link + _LOGGER.error( + "Unable to attach new link: %r", + e, + extra=self.network_trace_params + ) + await self._input_handles[frame[1]].detach() + + async def _outgoing_flow(self, frame=None): + link_flow = frame or {} + link_flow.update( + { + "next_incoming_id": self.next_incoming_id, + "incoming_window": self.incoming_window, + "next_outgoing_id": self.next_outgoing_id, + "outgoing_window": self.outgoing_window, + } + ) + flow_frame = FlowFrame(**link_flow) + if self.network_trace: + _LOGGER.debug("-> %r", flow_frame, extra=self.network_trace_params) + await self._connection._process_outgoing_frame(self.channel, flow_frame) # pylint: disable=protected-access + + async def _incoming_flow(self, frame): + if self.network_trace: + _LOGGER.debug("<- %r", FlowFrame(*frame), extra=self.network_trace_params) + self.next_incoming_id = frame[2] # next_outgoing_id + remote_incoming_id = frame[0] or self.next_outgoing_id # next_incoming_id TODO "initial-outgoing-id" + self.remote_incoming_window = remote_incoming_id + frame[1] - self.next_outgoing_id # incoming_window + self.remote_outgoing_window = frame[3] # outgoing_window + if frame[4] is not None: # handle + await self._input_handles[frame[4]]._incoming_flow(frame) # pylint: disable=protected-access + else: + for link in self._output_handles.values(): + if self.remote_incoming_window > 0 and not link._is_closed: # pylint: disable=protected-access + await link._incoming_flow(frame) # pylint: disable=protected-access + + 
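The `_outgoing_transfer` method that follows splits a payload that exceeds the negotiated frame size across multiple Transfer frames: every frame but the last is sent with `more=True`, and the per-frame capacity is the remote max frame size minus the encoded transfer overhead minus the 8-byte frame header. A minimal standalone sketch of that chunking arithmetic; `split_payload` is an illustrative helper, not part of this codebase.

```python
# Illustrative sketch of the chunking rule used by Session._outgoing_transfer:
# available payload per frame = remote_max_frame_size - transfer_overhead - 8,
# where transfer_overhead is the encoded size of the Transfer performative
# and 8 bytes is the AMQP frame header.
def split_payload(payload, remote_max_frame_size, transfer_overhead):
    available = remote_max_frame_size - transfer_overhead - 8
    if available <= 0:
        raise ValueError("Remote max frame size leaves no room for payload.")
    chunks = []
    start = 0
    # The first n-1 frames carry exactly `available` bytes with more=True...
    while len(payload) - start > available:
        chunks.append((payload[start:start + available], True))
        start += available
    # ...and the final frame carries the remainder with more=False.
    chunks.append((payload[start:], False))
    return chunks
```

A payload that fits within a single frame produces exactly one chunk with `more=False`, matching the single-frame path in the method below.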
async def _outgoing_transfer(self, delivery, network_trace_params): + if self.state != SessionState.MAPPED: + delivery.transfer_state = SessionTransferState.ERROR + if self.remote_incoming_window <= 0: + delivery.transfer_state = SessionTransferState.BUSY + else: + payload = delivery.frame["payload"] + payload_size = len(payload) + + delivery.frame["delivery_id"] = self.next_outgoing_id + # calculate the transfer frame encoding size excluding the payload + delivery.frame["payload"] = b"" + # TODO: encoding a frame would be expensive, we might want to improve depending on the perf test results + encoded_frame = encode_frame(TransferFrame(**delivery.frame))[1] + transfer_overhead_size = len(encoded_frame) + + # available size for payload per frame is calculated as following: + # remote max frame size - transfer overhead (calculated) - header (8 bytes) + available_frame_size = ( + self._connection._remote_max_frame_size - transfer_overhead_size - 8 # pylint: disable=protected-access + ) + + start_idx = 0 + remaining_payload_cnt = payload_size + # encode n-1 frames if payload_size > available_frame_size + while remaining_payload_cnt > available_frame_size: + tmp_delivery_frame = { + "handle": delivery.frame["handle"], + "delivery_tag": delivery.frame["delivery_tag"], + "message_format": delivery.frame["message_format"], + "settled": delivery.frame["settled"], + "more": True, + "rcv_settle_mode": delivery.frame["rcv_settle_mode"], + "state": delivery.frame["state"], + "resume": delivery.frame["resume"], + "aborted": delivery.frame["aborted"], + "batchable": delivery.frame["batchable"], + "delivery_id": self.next_outgoing_id, + } + if network_trace_params: + # We determine the logging for the outgoing Transfer frames based on the source + # Link configuration rather than the Session, because it's only at the Session + # level that we can determine how many outgoing frames are needed and their + # delivery IDs. 
+ # TODO: Obscuring the payload for now to investigate the potential for leaks. + _LOGGER.debug( + "-> %r", TransferFrame(payload=b"***", **tmp_delivery_frame), + extra=network_trace_params + ) + await self._connection._process_outgoing_frame( # pylint: disable=protected-access + self.channel, + TransferFrame( + payload=payload[start_idx : start_idx + available_frame_size], + **tmp_delivery_frame + ) + ) + start_idx += available_frame_size + remaining_payload_cnt -= available_frame_size + + # encode the last frame + tmp_delivery_frame = { + "handle": delivery.frame["handle"], + "delivery_tag": delivery.frame["delivery_tag"], + "message_format": delivery.frame["message_format"], + "settled": delivery.frame["settled"], + "more": False, + "rcv_settle_mode": delivery.frame["rcv_settle_mode"], + "state": delivery.frame["state"], + "resume": delivery.frame["resume"], + "aborted": delivery.frame["aborted"], + "batchable": delivery.frame["batchable"], + "delivery_id": self.next_outgoing_id, + } + if network_trace_params: + # We determine the logging for the outgoing Transfer frames based on the source + # Link configuration rather than the Session, because it's only at the Session + # level that we can determine how many outgoing frames are needed and their + # delivery IDs. + # TODO: Obscuring the payload for now to investigate the potential for leaks. 
+ _LOGGER.debug( + "-> %r", TransferFrame(payload=b"***", **tmp_delivery_frame), + extra=network_trace_params + ) + await self._connection._process_outgoing_frame( # pylint: disable=protected-access + self.channel, + TransferFrame(payload=payload[start_idx:], **tmp_delivery_frame) + ) + self.next_outgoing_id += 1 + self.remote_incoming_window -= 1 + self.outgoing_window -= 1 + # TODO: We should probably handle an error at the connection and update state accordingly + delivery.transfer_state = SessionTransferState.OKAY + + async def _incoming_transfer(self, frame): + self.next_incoming_id += 1 + self.remote_outgoing_window -= 1 + self.incoming_window -= 1 + try: + await self._input_handles[frame[0]]._incoming_transfer(frame) # pylint: disable=protected-access + except KeyError: + _LOGGER.error( + "Received Transfer frame on unattached link. Ending session.", + extra=self.network_trace_params + ) + await self._set_state(SessionState.DISCARDING) + await self.end( + error=AMQPError( + condition=ErrorCondition.SessionUnattachedHandle, + description="""Invalid handle reference in received frame: """ + """Handle is not currently associated with an attached link""", + ) + ) + if self.incoming_window == 0: + self.incoming_window = self.target_incoming_window + await self._outgoing_flow() + + async def _outgoing_disposition(self, frame): + await self._connection._process_outgoing_frame(self.channel, frame) # pylint: disable=protected-access + + async def _incoming_disposition(self, frame): + if self.network_trace: + _LOGGER.debug("<- %r", DispositionFrame(*frame), extra=self.network_trace_params) + for link in self._input_handles.values(): + await link._incoming_disposition(frame) # pylint: disable=protected-access + + async def _outgoing_detach(self, frame): + await self._connection._process_outgoing_frame(self.channel, frame) # pylint: disable=protected-access + + async def _incoming_detach(self, frame): + try: + link = self._input_handles[frame[0]] # handle + await 
link._incoming_detach(frame) # pylint: disable=protected-access + # if link._is_closed: TODO + # self.links.pop(link.name, None) + # self._input_handles.pop(link.remote_handle, None) + # self._output_handles.pop(link.handle, None) + except KeyError: + await self._set_state(SessionState.DISCARDING) + await self._connection.close( + error=AMQPError( + condition=ErrorCondition.SessionUnattachedHandle, + description="""Invalid handle reference in received frame: """ + """Handle is not currently associated with an attached link""", + ) + ) + + async def _wait_for_response(self, wait, end_state): + # type: (Union[bool, float], SessionState) -> None + if wait is True: + await self._connection.listen(wait=False) + while self.state != end_state: + await asyncio.sleep(self.idle_wait_time) + await self._connection.listen(wait=False) + elif wait: + await self._connection.listen(wait=False) + timeout = time.time() + wait + while self.state != end_state: + if time.time() >= timeout: + break + await asyncio.sleep(self.idle_wait_time) + await self._connection.listen(wait=False) + + async def begin(self, wait=False): + await self._outgoing_begin() + await self._set_state(SessionState.BEGIN_SENT) + if wait: + await self._wait_for_response(wait, SessionState.BEGIN_SENT) + elif not self.allow_pipelined_open: + raise ValueError("Connection has been configured to not allow pipelined-open. 
Please set 'wait' parameter.") + + async def end(self, error=None, wait=False): + # type: (Optional[AMQPError], bool) -> None + try: + if self.state not in [SessionState.UNMAPPED, SessionState.DISCARDING]: + await self._outgoing_end(error=error) + for _, link in self.links.items(): + await link.detach() + new_state = SessionState.DISCARDING if error else SessionState.END_SENT + await self._set_state(new_state) + await self._wait_for_response(wait, SessionState.UNMAPPED) + except Exception as exc: # pylint: disable=broad-except + _LOGGER.info("An error occurred when ending the session: %r", exc, extra=self.network_trace_params) + await self._set_state(SessionState.UNMAPPED) + + def create_receiver_link(self, source_address, **kwargs): + assigned_handle = self._get_next_output_handle() + link = ReceiverLink( + self, + handle=assigned_handle, + source_address=source_address, + network_trace=kwargs.pop("network_trace", self.network_trace), + network_trace_params=dict(self.network_trace_params), + **kwargs, + ) + self.links[link.name] = link + self._output_handles[assigned_handle] = link + return link + + def create_sender_link(self, target_address, **kwargs): + assigned_handle = self._get_next_output_handle() + link = SenderLink( + self, + handle=assigned_handle, + target_address=target_address, + network_trace=kwargs.pop("network_trace", self.network_trace), + network_trace_params=dict(self.network_trace_params), + **kwargs, + ) + self._output_handles[assigned_handle] = link + self.links[link.name] = link + return link + + def create_request_response_link_pair(self, endpoint, **kwargs): + return ManagementLink( + self, + endpoint, + network_trace=kwargs.pop("network_trace", self.network_trace), + network_trace_params=dict(self.network_trace_params), + **kwargs, + ) diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/aio/_transport_async.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/aio/_transport_async.py new file mode 100644 index 
0000000000000..69bd96527e845 --- /dev/null +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/aio/_transport_async.py @@ -0,0 +1,547 @@ +# ------------------------------------------------------------------------- # pylint: disable=file-needs-copyright-header +# This is a fork of the transport.py which was originally written by Barry Pederson and +# maintained by the Celery project: https://github.com/celery/py-amqp. +# +# Copyright (C) 2009 Barry Pederson +# +# The license text can also be found here: +# http://www.opensource.org/licenses/BSD-3-Clause +# +# License +# ======= +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS +# BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +# THE POSSIBILITY OF SUCH DAMAGE. 
+# ------------------------------------------------------------------------- + +import asyncio +import errno +import socket +import ssl +import struct +from ssl import SSLError +from io import BytesIO +import logging + + + +import certifi + +from .._platform import KNOWN_TCP_OPTS, SOL_TCP +from .._encode import encode_frame +from .._decode import decode_frame, decode_empty_frame +from ..constants import ( + DEFAULT_WEBSOCKET_HEARTBEAT_SECONDS, + TLS_HEADER_FRAME, + WEBSOCKET_PORT, + AMQP_WS_SUBPROTOCOL, + TIMEOUT_INTERVAL, +) +from .._transport import ( + AMQP_FRAME, + get_errno, + to_host_port, + DEFAULT_SOCKET_SETTINGS, + SIGNED_INT_MAX, + _UNAVAIL, + AMQP_PORT, +) +from ..error import AuthenticationException, ErrorCondition + + +_LOGGER = logging.getLogger(__name__) + + +class AsyncTransportMixin: + async def receive_frame(self, timeout=None, **kwargs): + try: + header, channel, payload = await asyncio.wait_for( + self.read(**kwargs), timeout=timeout + ) + if not payload: + decoded = decode_empty_frame(header) + else: + decoded = decode_frame(payload) + return channel, decoded + except ( + TimeoutError, + socket.timeout, + asyncio.IncompleteReadError, + asyncio.TimeoutError, + ): + return None, None + + async def read(self, verify_frame_type=0): + async with self.socket_lock: + read_frame_buffer = BytesIO() + try: + frame_header = memoryview(bytearray(8)) + read_frame_buffer.write( + await self._read(8, buffer=frame_header, initial=True) + ) + + channel = struct.unpack(">H", frame_header[6:])[0] + size = frame_header[0:4] + if size == AMQP_FRAME: # Empty frame or AMQP header negotiation + return frame_header, channel, None + size = struct.unpack(">I", size)[0] + offset = frame_header[4] + frame_type = frame_header[5] + if verify_frame_type is not None and frame_type != verify_frame_type: + _LOGGER.debug( + "Received invalid frame type: %r, expected: %r", + frame_type, + verify_frame_type, + extra=self.network_trace_params + ) + raise ValueError( + f"Received 
invalid frame type: {frame_type}, expected: {verify_frame_type}" + ) + # >I is an unsigned int, but the argument to sock.recv is signed, + # so we know the size can be at most 2 * SIGNED_INT_MAX + payload_size = size - len(frame_header) + payload = memoryview(bytearray(payload_size)) + if size > SIGNED_INT_MAX: + read_frame_buffer.write( + await self._read(SIGNED_INT_MAX, buffer=payload) + ) + read_frame_buffer.write( + await self._read( + size - SIGNED_INT_MAX, buffer=payload[SIGNED_INT_MAX:] + ) + ) + else: + read_frame_buffer.write( + await self._read(payload_size, buffer=payload) + ) + except ( + asyncio.CancelledError, + asyncio.TimeoutError, + TimeoutError, + socket.timeout, + asyncio.IncompleteReadError + ): + read_frame_buffer.write(self._read_buffer.getvalue()) + self._read_buffer = read_frame_buffer + self._read_buffer.seek(0) + raise + except (OSError, IOError, SSLError, socket.error) as exc: + # Don't disconnect for ssl read time outs + # http://bugs.python.org/issue10272 + if isinstance(exc, SSLError) and "timed out" in str(exc): + raise socket.timeout() + if get_errno(exc) not in _UNAVAIL: + self.connected = False + _LOGGER.debug("Transport read failed: %r", exc, extra=self.network_trace_params) + raise + offset -= 2 + return frame_header, channel, payload[offset:] + + async def write(self, s): + async with self.socket_lock: + try: + await self._write(s) + except socket.timeout: + raise + except (OSError, IOError, socket.error) as exc: + _LOGGER.debug("Transport write failed: %r", exc, extra=self.network_trace_params) + if get_errno(exc) not in _UNAVAIL: + self.connected = False + raise + + async def send_frame(self, channel, frame, **kwargs): + header, performative = encode_frame(frame, **kwargs) + if performative is None: + data = header + else: + encoded_channel = struct.pack(">H", channel) + data = header + encoded_channel + performative + + await self.write(data) + + def _build_ssl_opts(self, sslopts): + if sslopts in [True, False, None, {}]: + 
return sslopts + try: + if "context" in sslopts: + return self._build_ssl_context(**sslopts.pop("context")) + ssl_version = sslopts.get("ssl_version") + if ssl_version is None: + ssl_version = ssl.PROTOCOL_TLS + + # Set SNI headers if supported + server_hostname = sslopts.get("server_hostname") + if ( + (server_hostname is not None) + and (hasattr(ssl, "HAS_SNI") and ssl.HAS_SNI) + and (hasattr(ssl, "SSLContext")) + ): + context = ssl.SSLContext(ssl_version) + cert_reqs = sslopts.get("cert_reqs", ssl.CERT_REQUIRED) + certfile = sslopts.get("certfile") + keyfile = sslopts.get("keyfile") + context.verify_mode = cert_reqs + if cert_reqs != ssl.CERT_NONE: + context.check_hostname = True + if (certfile is not None) and (keyfile is not None): + context.load_cert_chain(certfile, keyfile) + return context + ca_certs = sslopts.get("ca_certs") + if ca_certs: + context = ssl.SSLContext(ssl_version) + context.load_verify_locations(ca_certs) + return context + return True + except TypeError: + raise TypeError( + "SSL configuration must be a dictionary, or the value True." 
+ ) + + def _build_ssl_context( + self, check_hostname=None, **ctx_options + ): # pylint: disable=no-self-use + ctx = ssl.create_default_context(**ctx_options) + ctx.verify_mode = ssl.CERT_REQUIRED + ctx.load_verify_locations(cafile=certifi.where()) + ctx.check_hostname = check_hostname + return ctx + + +class AsyncTransport( + AsyncTransportMixin +): # pylint: disable=too-many-instance-attributes + """Common superclass for TCP and SSL transports.""" + + def __init__( + self, + host, + *, + port=AMQP_PORT, + connect_timeout=None, + ssl_opts=False, + socket_settings=None, + raise_on_initial_eintr=True, + **kwargs, # pylint: disable=unused-argument + ): + self.connected = False + self.sock = None + self.reader = None + self.writer = None + self.raise_on_initial_eintr = raise_on_initial_eintr + self._read_buffer = BytesIO() + self.host, self.port = to_host_port(host, port) + + self.connect_timeout = connect_timeout + self.socket_settings = socket_settings + self.socket_lock = asyncio.Lock() + self.sslopts = ssl_opts + self.network_trace_params = kwargs.get('network_trace_params') + + async def connect(self): + try: + # are we already connected? + if self.connected: + return + try: + # Building ssl opts here instead of constructor, so that invalid cert error is raised + # when client is connecting, rather then during creation. For uamqp exception parity. + self.sslopts = self._build_ssl_opts(self.sslopts) + except FileNotFoundError as exc: + # FileNotFoundError does not have missing filename info, so adding it below. + # Assuming that this must be ca_certs, since this is the only file path that + # users can pass in (`connection_verify` in the EH/SB clients) through sslopts above. + # For uamqp exception parity. Remove later when resolving issue #27128. 
+ exc.filename = self.sslopts + raise exc + self.reader, self.writer = await asyncio.open_connection( + host=self.host, + port=self.port, + ssl=self.sslopts, + family=socket.AF_UNSPEC, + proto=SOL_TCP, + server_hostname=self.host if self.sslopts else None, + ) + self.connected = True + sock = self.writer.transport.get_extra_info("socket") + if sock: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) + self._set_socket_options(sock, self.socket_settings) + + + except (OSError, IOError, SSLError) as e: + _LOGGER.info("Transport connect failed: %r", e, extra=self.network_trace_params) + # if not fully connected, close socket, and reraise error + if self.writer and not self.connected: + self.writer.close() + await self.writer.wait_closed() + self.connected = False + raise + + def _get_tcp_socket_defaults(self, sock): # pylint: disable=no-self-use + tcp_opts = {} + for opt in KNOWN_TCP_OPTS: + enum = None + if opt == "TCP_USER_TIMEOUT": + try: + from socket import TCP_USER_TIMEOUT as enum + except ImportError: + # should be in Python 3.6+ on Linux. + enum = 18 + elif hasattr(socket, opt): + enum = getattr(socket, opt) + + if enum: + if opt in DEFAULT_SOCKET_SETTINGS: + tcp_opts[enum] = DEFAULT_SOCKET_SETTINGS[opt] + elif hasattr(socket, opt): + tcp_opts[enum] = sock.getsockopt(SOL_TCP, getattr(socket, opt)) + return tcp_opts + + def _set_socket_options(self, sock, socket_settings): + tcp_opts = self._get_tcp_socket_defaults(sock) + if socket_settings: + tcp_opts.update(socket_settings) + for opt, val in tcp_opts.items(): + sock.setsockopt(SOL_TCP, opt, val) + + async def _read( + self, + toread, + initial=False, + buffer=None, + _errnos=(errno.ENOENT, errno.EAGAIN, errno.EINTR), + ): + # According to SSL_read(3), it can at most return 16kb of data. + # Thus, we use an internal read buffer like TCPTransport._read + # to get the exact number of bytes wanted. 
+ length = 0 + view = buffer or memoryview(bytearray(toread)) + nbytes = self._read_buffer.readinto(view) + toread -= nbytes + length += nbytes + try: + while toread: + try: + view[nbytes : nbytes + toread] = await self.reader.readexactly( + toread + ) + nbytes = toread + except AttributeError: + # This means that close() was called concurrently: + # self.reader has been set to None. + raise IOError("Connection has already been closed") + except asyncio.IncompleteReadError as exc: + pbytes = len(exc.partial) + view[nbytes : nbytes + pbytes] = exc.partial + nbytes = pbytes + except socket.error as exc: + # ssl.sock.read may cause an SSLError without errno + # http://bugs.python.org/issue10272 + if isinstance(exc, SSLError) and "timed out" in str(exc): + raise socket.timeout() + # errno 110 is equivalent to ETIMEDOUT on linux non blocking sockets, when a keep alive is set, + # and is set when the connection to the server doesn't succeed + # https://man7.org/linux/man-pages/man7/tcp.7.html. + # This behavior is Linux-specific and async-only: sync Linux and async/sync Windows & Mac raise + # ConnectionAborted or ConnectionReset errors which properly end up in a retry loop. + if exc.errno in [110]: + raise ConnectionAbortedError('The connection was closed abruptly.') + # ssl.sock.read may cause ENOENT if the + # operation couldn't be performed (Issue celery#1414). 
+ if exc.errno in _errnos: + if initial and self.raise_on_initial_eintr: + raise socket.timeout() + continue + raise + if not nbytes: + raise IOError("Server unexpectedly closed connection") + + length += nbytes + toread -= nbytes + except: # noqa + self._read_buffer = BytesIO(view[:length]) + raise + return view + + async def _write(self, s): + """Write a string out to the SSL socket fully.""" + try: + self.writer.write(s) + await self.writer.drain() + except AttributeError: + raise IOError("Connection has already been closed") + + async def close(self): + async with self.socket_lock: + try: + if self.writer is not None: + # Closing the writer closes the underlying socket. + self.writer.close() + if self.sslopts: + # see issue: https://github.com/encode/httpx/issues/914 + await asyncio.sleep(0) + self.writer.transport.abort() + await self.writer.wait_closed() + except Exception as e: # pylint: disable=broad-except + # Sometimes SSL raises APPLICATION_DATA_AFTER_CLOSE_NOTIFY here on close. + _LOGGER.debug("Error shutting down socket: %r", e, extra=self.network_trace_params) + self.writer, self.reader = None, None + self.connected = False + + async def negotiate(self): + if not self.sslopts: + return + await self.write(TLS_HEADER_FRAME) + _, returned_header = await self.receive_frame(verify_frame_type=None) + if returned_header[1] == TLS_HEADER_FRAME: + raise ValueError( + f"""Mismatching TLS header protocol. 
Expected: {TLS_HEADER_FRAME!r},""" + f"""received: {returned_header[1]!r}""" + ) + + +class WebSocketTransportAsync( + AsyncTransportMixin +): # pylint: disable=too-many-instance-attributes + def __init__( + self, + host, + *, + port=WEBSOCKET_PORT, + connect_timeout=None, + ssl_opts=None, + **kwargs + ): + self._read_buffer = BytesIO() + self.socket_lock = asyncio.Lock() + self.sslopts = ssl_opts if isinstance(ssl_opts, dict) else None + self._connect_timeout = connect_timeout or TIMEOUT_INTERVAL + self._custom_endpoint = kwargs.get("custom_endpoint") + self.host, self.port = to_host_port(host, port) + self.ws = None + self.session = None + self._http_proxy = kwargs.get("http_proxy", None) + self.connected = False + self.network_trace_params = kwargs.get('network_trace_params') + + async def connect(self): + self.sslopts = self._build_ssl_opts(self.sslopts) + username, password = None, None + http_proxy_host, http_proxy_port = None, None + http_proxy_auth = None + + if self._http_proxy: + http_proxy_host = self._http_proxy["proxy_hostname"] + http_proxy_port = self._http_proxy["proxy_port"] + if http_proxy_host and http_proxy_port: + http_proxy_host = f"{http_proxy_host}:{http_proxy_port}" + username = self._http_proxy.get("username", None) + password = self._http_proxy.get("password", None) + + try: + from aiohttp import ClientSession, ClientConnectorError + from urllib.parse import urlsplit + except ImportError: + raise ImportError( + "Please install aiohttp library to use async websocket transport." 
+ ) + + if username or password: + from aiohttp import BasicAuth + + http_proxy_auth = BasicAuth(login=username, password=password) + + self.session = ClientSession() + if self._custom_endpoint: + url = f"wss://{self._custom_endpoint}" + else: + url = f"wss://{self.host}" + parsed_url = urlsplit(url) + url = f"{parsed_url.scheme}://{parsed_url.netloc}:{self.port}{parsed_url.path}" + + try: + # Enabling heartbeat that sends a ping message every n seconds and waits for pong response. + # if pong response is not received then close connection. This raises an error when trying + # to communicate with the websocket which is no longer active. + # We are waiting a bug fix in aiohttp for these 2 bugs where aiohttp ws might hang on network disconnect + # and the heartbeat mechanism helps mitigate these two. + # https://github.com/aio-libs/aiohttp/pull/5860 + # https://github.com/aio-libs/aiohttp/issues/2309 + + self.ws = await self.session.ws_connect( + url=url, + timeout=self._connect_timeout, + protocols=[AMQP_WS_SUBPROTOCOL], + autoclose=False, + proxy=http_proxy_host, + proxy_auth=http_proxy_auth, + ssl=self.sslopts, + heartbeat=DEFAULT_WEBSOCKET_HEARTBEAT_SECONDS, + ) + except ClientConnectorError as exc: + _LOGGER.info("Websocket connect failed: %r", exc, extra=self.network_trace_params) + if self._custom_endpoint: + raise AuthenticationException( + ErrorCondition.ClientError, + description="Failed to authenticate the connection due to exception: " + str(exc), + error=exc, + ) + raise ConnectionError("Failed to establish websocket connection: " + str(exc)) + self.connected = True + + async def _read(self, toread, buffer=None, **kwargs): # pylint: disable=unused-argument + """Read exactly n bytes from the peer.""" + length = 0 + view = buffer or memoryview(bytearray(toread)) + nbytes = self._read_buffer.readinto(view) + length += nbytes + toread -= nbytes + try: + while toread: + data = await self.ws.receive_bytes() + read_length = len(data) + if read_length <= toread: 
+ view[length : length + read_length] = data + toread -= read_length + length += read_length + else: + view[length : length + toread] = data[0:toread] + self._read_buffer = BytesIO(data[toread:]) + toread = 0 + return view + except: + self._read_buffer = BytesIO(view[:length]) + raise + + async def close(self): + """Do any preliminary work in shutting down the connection.""" + async with self.socket_lock: + await self.ws.close() + await self.session.close() + self.connected = False + + async def _write(self, s): + """Completely write a string (byte array) to the peer. + ABNF, OPCODE_BINARY = 0x2 + See http://tools.ietf.org/html/rfc5234 + http://tools.ietf.org/html/rfc6455#section-5.2 + """ + await self.ws.send_bytes(s) diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/authentication.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/authentication.py new file mode 100644 index 0000000000000..3e794139986f5 --- /dev/null +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/authentication.py @@ -0,0 +1,175 @@ +#------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. 
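`AsyncTransportMixin.read` above unpacks the fixed 8-byte AMQP frame header before draining the payload. Per AMQP 1.0 framing, that header is a big-endian uint32 total frame size, a DOFF byte counting 4-byte words to the frame body, a type byte, and a uint16 channel. A standalone parser sketch (the function name is illustrative, not part of this module):

```python
import struct

def parse_frame_header(header: bytes):
    """Decode the 8-byte AMQP frame header into
    (total_size, body_offset_bytes, frame_type, channel)."""
    size = struct.unpack(">I", header[0:4])[0]
    doff = header[4]        # data offset, in 4-byte words from frame start
    frame_type = header[5]  # 0x00 = AMQP frame, 0x01 = SASL frame
    channel = struct.unpack(">H", header[6:8])[0]
    return size, doff * 4, frame_type, channel
```

A typical frame uses DOFF 2, so the performative body begins immediately after the 8-byte header, which is why the `read` method above can compute the payload slice from the raw offset byte.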
+#------------------------------------------------------------------------- + +import time +from collections import namedtuple +from functools import partial + +from .sasl import SASLAnonymousCredential, SASLPlainCredential +from .utils import generate_sas_token + +from .constants import ( + AUTH_DEFAULT_EXPIRATION_SECONDS, + TOKEN_TYPE_JWT, + TOKEN_TYPE_SASTOKEN, + AUTH_TYPE_CBS, + AUTH_TYPE_SASL_PLAIN +) + +AccessToken = namedtuple("AccessToken", ["token", "expires_on"]) + + +def _generate_sas_access_token(auth_uri, sas_name, sas_key, expiry_in=AUTH_DEFAULT_EXPIRATION_SECONDS): + expires_on = int(time.time() + expiry_in) + token = generate_sas_token(auth_uri, sas_name, sas_key, expires_on) + return AccessToken( + token, + expires_on + ) + + +class SASLPlainAuth(object): + # TODO: + # 1. naming decision, suffix with Auth vs Credential + auth_type = AUTH_TYPE_SASL_PLAIN + + def __init__(self, authcid, passwd, authzid=None): + self.sasl = SASLPlainCredential(authcid, passwd, authzid) + + +class _CBSAuth(object): + # TODO: + # 1. naming decision, suffix with Auth vs Credential + auth_type = AUTH_TYPE_CBS + + def __init__( + self, + uri, + audience, + token_type, + get_token, + **kwargs + ): + """ + CBS authentication using JWT tokens. + + :param uri: The AMQP endpoint URI. This must be provided as + a decoded string. + :type uri: str + :param audience: The token audience field. For SAS tokens + this is usually the URI. + :type audience: str + :param get_token: The callback function used for getting and refreshing + tokens. It should return a valid jwt token each time it is called. + :type get_token: callable object + :param token_type: The type field of the token request. + Default value is `"jwt"`. 
+ :type token_type: str + + """ + self.sasl = SASLAnonymousCredential() + self.uri = uri + self.audience = audience + self.token_type = token_type + self.get_token = get_token + self.expires_in = kwargs.pop("expires_in", AUTH_DEFAULT_EXPIRATION_SECONDS) + self.expires_on = kwargs.pop("expires_on", None) + + @staticmethod + def _set_expiry(expires_in, expires_on): + if not expires_on and not expires_in: + raise ValueError("Must specify either 'expires_on' or 'expires_in'.") + if not expires_on: + expires_on = time.time() + expires_in + else: + expires_in = expires_on - time.time() + if expires_in < 1: + raise ValueError("Token has already expired.") + return expires_in, expires_on + + +class JWTTokenAuth(_CBSAuth): + # TODO: + # 1. naming decision, suffix with Auth vs Credential + def __init__( + self, + uri, + audience, + get_token, + **kwargs + ): + """ + CBS authentication using JWT tokens. + + :param uri: The AMQP endpoint URI. This must be provided as + a decoded string. + :type uri: str + :param audience: The token audience field. For SAS tokens + this is usually the URI. + :type audience: str + :param get_token: The callback function used for getting and refreshing + tokens. It should return a valid jwt token each time it is called. + :type get_token: callable object + :param token_type: The type field of the token request. + Default value is `"jwt"`. + :type token_type: str + + """ + super(JWTTokenAuth, self).__init__(uri, audience, kwargs.pop("token_type", TOKEN_TYPE_JWT), get_token) + self.get_token = get_token + + +class SASTokenAuth(_CBSAuth): + # TODO: + # 1. naming decision, suffix with Auth vs Credential + def __init__( + self, + uri, + audience, + username, + password, + **kwargs + ): + """ + CBS authentication using SAS tokens. + + :param uri: The AMQP endpoint URI. This must be provided as + a decoded string. + :type uri: str + :param audience: The token audience field. For SAS tokens + this is usually the URI. 
+ :type audience: str + :param username: The SAS token username, also referred to as the key + name or policy name. This can optionally be encoded into the URI. + :type username: str + :param password: The SAS token password, also referred to as the key. + This can optionally be encoded into the URI. + :type password: str + :param expires_in: The total remaining seconds until the token + expires. + :type expires_in: int + :param expires_on: The timestamp at which the SAS token will expire + formatted as seconds since epoch. + :type expires_on: float + :param token_type: The type field of the token request. + Default value is `"servicebus.windows.net:sastoken"`. + :type token_type: str + + """ + self.username = username + self.password = password + expires_in = kwargs.pop("expires_in", AUTH_DEFAULT_EXPIRATION_SECONDS) + expires_on = kwargs.pop("expires_on", None) + expires_in, expires_on = self._set_expiry(expires_in, expires_on) + self.get_token = partial(_generate_sas_access_token, uri, username, password, expires_in) + super(SASTokenAuth, self).__init__( + uri, + audience, + kwargs.pop("token_type", TOKEN_TYPE_SASTOKEN), + self.get_token, + expires_in=expires_in, + expires_on=expires_on + ) diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/cbs.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/cbs.py new file mode 100644 index 0000000000000..bef81f1792d65 --- /dev/null +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/cbs.py @@ -0,0 +1,299 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. 
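`_CBSAuth._set_expiry` above normalizes the `expires_in`/`expires_on` pair so callers can supply either a relative lifetime or an absolute epoch expiry. Its logic in isolation, as a sketch rather than the exported API:

```python
import time

def set_expiry(expires_in=None, expires_on=None):
    """Given a token lifetime in seconds or an absolute epoch expiry,
    derive the other value; reject tokens that are already expired."""
    if not expires_on and not expires_in:
        raise ValueError("Must specify either 'expires_on' or 'expires_in'.")
    if not expires_on:
        expires_on = time.time() + expires_in
    else:
        expires_in = expires_on - time.time()
        if expires_in < 1:
            raise ValueError("Token has already expired.")
    return expires_in, expires_on
```

`SASTokenAuth` then feeds the resolved `expires_in` into the partial that regenerates SAS tokens, so each refreshed token gets the same lifetime as the first.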
+# ------------------------------------------------------------------------- + +import logging +from datetime import datetime +from typing import Optional + +from .utils import utc_now, utc_from_timestamp +from .management_link import ManagementLink +from .message import Message, Properties +from .error import ( + AuthenticationException, + ErrorCondition, + TokenAuthFailure, + TokenExpired, +) +from .constants import ( + CbsState, + CbsAuthState, + CBS_PUT_TOKEN, + CBS_EXPIRATION, + CBS_NAME, + CBS_TYPE, + CBS_OPERATION, + ManagementExecuteOperationResult, + ManagementOpenResult, +) + +_LOGGER = logging.getLogger(__name__) + + +def check_expiration_and_refresh_status(expires_on, refresh_window): + seconds_since_epoc = int(utc_now().timestamp()) + is_expired = seconds_since_epoc >= expires_on + is_refresh_required = (expires_on - seconds_since_epoc) <= refresh_window + return is_expired, is_refresh_required + + +def check_put_timeout_status(auth_timeout, token_put_time): + if auth_timeout > 0: + return (int(utc_now().timestamp()) - token_put_time) >= auth_timeout + return False + + +class CBSAuthenticator(object): # pylint:disable=too-many-instance-attributes + def __init__(self, session, auth, **kwargs): + self._session = session + self._connection = self._session._connection + self._mgmt_link = self._session.create_request_response_link_pair( + endpoint="$cbs", + on_amqp_management_open_complete=self._on_amqp_management_open_complete, + on_amqp_management_error=self._on_amqp_management_error, + status_code_field=b"status-code", + status_description_field=b"status-description", + ) # type: ManagementLink + + if not auth.get_token or not callable(auth.get_token): + raise ValueError("get_token must be a callable object.") + + self._auth = auth + self._encoding = "UTF-8" + self._auth_timeout = kwargs.get("auth_timeout") + self._token_put_time = None + self._expires_on = None + self._token = None + self._refresh_window = None + self._network_trace_params = { + 
"amqpConnection": self._session._connection._container_id, + "amqpSession": self._session.name, + "amqpLink": None + } + + self._token_status_code = None + self._token_status_description = None + + self.state = CbsState.CLOSED + self.auth_state = CbsAuthState.IDLE + + def _put_token(self, token: str, token_type: str, audience: str, expires_on: Optional[datetime] = None) -> None: + message = Message( # type: ignore # TODO: missing positional args header, etc. + value=token, + properties=Properties(message_id=self._mgmt_link.next_message_id), # type: ignore + application_properties={ + CBS_NAME: audience, + CBS_OPERATION: CBS_PUT_TOKEN, + CBS_TYPE: token_type, + CBS_EXPIRATION: expires_on, + }, + ) + self._mgmt_link.execute_operation( + message, + self._on_execute_operation_complete, + timeout=self._auth_timeout, + operation=CBS_PUT_TOKEN, + type=token_type, + ) + self._mgmt_link.next_message_id += 1 + + def _on_amqp_management_open_complete(self, management_open_result): + if self.state in (CbsState.CLOSED, CbsState.ERROR): + _LOGGER.debug( + "CSB with status: %r encounters unexpected AMQP management open complete.", + self.state, + extra=self._network_trace_params + ) + elif self.state == CbsState.OPEN: + self.state = CbsState.ERROR + _LOGGER.info( + "Unexpected AMQP management open complete in OPEN, CBS error occurred.", + extra=self._network_trace_params + ) + elif self.state == CbsState.OPENING: + self.state = ( + CbsState.OPEN + if management_open_result == ManagementOpenResult.OK + else CbsState.CLOSED + ) + _LOGGER.debug( + "CBS completed opening with status: %r", + management_open_result, + extra=self._network_trace_params + ) + + def _on_amqp_management_error(self): + if self.state == CbsState.CLOSED: + _LOGGER.info("Unexpected AMQP error in CLOSED state.", extra=self._network_trace_params) + elif self.state == CbsState.OPENING: + self.state = CbsState.ERROR + self._mgmt_link.close() + _LOGGER.info( + "CBS failed to open with status: %r", + 
ManagementOpenResult.ERROR, + extra=self._network_trace_params + ) + elif self.state == CbsState.OPEN: + self.state = CbsState.ERROR + _LOGGER.info("CBS error occurred.", extra=self._network_trace_params) + + def _on_execute_operation_complete( + self, + execute_operation_result, + status_code, + status_description, + _, + error_condition=None, + ): + if error_condition: + _LOGGER.info( + "CBS Put token error: %r", + error_condition, + extra=self._network_trace_params + ) + self.auth_state = CbsAuthState.ERROR + return + _LOGGER.debug( + "CBS Put token result (%r), status code: %s, status_description: %s.", + execute_operation_result, + status_code, + status_description, + extra=self._network_trace_params + ) + self._token_status_code = status_code + self._token_status_description = status_description + + if execute_operation_result == ManagementExecuteOperationResult.OK: + self.auth_state = CbsAuthState.OK + elif execute_operation_result == ManagementExecuteOperationResult.ERROR: + self.auth_state = CbsAuthState.ERROR + # put-token-message sending failure, rejected + self._token_status_code = 0 + self._token_status_description = "Auth message has been rejected." + elif ( + execute_operation_result + == ManagementExecuteOperationResult.FAILED_BAD_STATUS + ): + self.auth_state = CbsAuthState.ERROR + + def _update_status(self): + if ( + self.auth_state == CbsAuthState.OK + or self.auth_state == CbsAuthState.REFRESH_REQUIRED + ): + is_expired, is_refresh_required = check_expiration_and_refresh_status( + self._expires_on, self._refresh_window + ) + _LOGGER.debug( + "CBS status check: state == %r, expired == %r, refresh required == %r", + self.auth_state, + is_expired, + is_refresh_required, + extra=self._network_trace_params + ) + if is_expired: + self.auth_state = CbsAuthState.EXPIRED + elif is_refresh_required: + self.auth_state = CbsAuthState.REFRESH_REQUIRED + elif self.auth_state == CbsAuthState.IN_PROGRESS: + _LOGGER.debug( + "CBS update in progress. 
Token put time: %r", + self._token_put_time, + extra=self._network_trace_params + ) + put_timeout = check_put_timeout_status( + self._auth_timeout, self._token_put_time + ) + if put_timeout: + self.auth_state = CbsAuthState.TIMEOUT + + def _cbs_link_ready(self): + if self.state == CbsState.OPEN: + return True + if self.state != CbsState.OPEN: + return False + if self.state in (CbsState.CLOSED, CbsState.ERROR): + raise TokenAuthFailure( + status_code=ErrorCondition.ClientError, + status_description="CBS authentication link is in broken status, please recreate the cbs link.", + ) + + def open(self): + self.state = CbsState.OPENING + self._mgmt_link.open() + + def close(self): + self._mgmt_link.close() + self.state = CbsState.CLOSED + + def update_token(self): + self.auth_state = CbsAuthState.IN_PROGRESS + access_token = self._auth.get_token() + if not access_token: + _LOGGER.info( + "Token refresh function received an empty token object.", + extra=self._network_trace_params + ) + elif not access_token.token: + _LOGGER.info( + "Token refresh function received an empty token.", + extra=self._network_trace_params + ) + self._expires_on = access_token.expires_on + expires_in = self._expires_on - int(utc_now().timestamp()) + self._refresh_window = int(float(expires_in) * 0.1) + try: + self._token = access_token.token.decode() + except AttributeError: + self._token = access_token.token + try: + token_type = self._auth.token_type.decode() + except AttributeError: + token_type = self._auth.token_type + + self._token_put_time = int(utc_now().timestamp()) + self._put_token( + self._token, + token_type, + self._auth.audience, + utc_from_timestamp(self._expires_on), + ) + + def handle_token(self): + if not self._cbs_link_ready(): + return False + self._update_status() + if self.auth_state == CbsAuthState.IDLE: + self.update_token() + return False + if self.auth_state == CbsAuthState.IN_PROGRESS: + return False + if self.auth_state == CbsAuthState.OK: + return True + if 
self.auth_state == CbsAuthState.REFRESH_REQUIRED: + _LOGGER.info( + "Token will expire soon - attempting to refresh.", + extra=self._network_trace_params + ) + self.update_token() + return False + if self.auth_state == CbsAuthState.FAILURE: + raise AuthenticationException( + condition=ErrorCondition.InternalError, + description="Failed to open CBS authentication link.", + ) + if self.auth_state == CbsAuthState.ERROR: + raise TokenAuthFailure( + self._token_status_code, + self._token_status_description, + encoding=self._encoding, # TODO: drop off all the encodings + ) + if self.auth_state == CbsAuthState.TIMEOUT: + raise TimeoutError("Authentication attempt timed-out.") + if self.auth_state == CbsAuthState.EXPIRED: + raise TokenExpired( + condition=ErrorCondition.InternalError, + description="CBS Authentication Expired.", + ) diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/client.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/client.py new file mode 100644 index 0000000000000..befddc886ec6f --- /dev/null +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/client.py @@ -0,0 +1,1058 @@ +# ------------------------------------------------------------------------- # pylint: disable=client-suffix-needed +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. 
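The refresh decision driving the REFRESH_REQUIRED/EXPIRED transitions above reduces to the `check_expiration_and_refresh_status` helper defined at the top of cbs.py. A standalone sketch of that check, using the 10% refresh window that `update_token` computes:

```python
import time


def check_expiration_and_refresh_status(expires_on, refresh_window):
    # Same comparison as the module-level helper in cbs.py: both values
    # are whole seconds since the epoch.
    seconds_since_epoch = int(time.time())
    is_expired = seconds_since_epoch >= expires_on
    is_refresh_required = (expires_on - seconds_since_epoch) <= refresh_window
    return is_expired, is_refresh_required


# update_token() sets the refresh window to 10% of the token lifetime,
# so a token valid for 3600s triggers a refresh in its final 360 seconds.
expires_on = int(time.time()) + 3600
refresh_window = int(3600 * 0.1)
is_expired, needs_refresh = check_expiration_and_refresh_status(expires_on, refresh_window)
```

A fresh token reports `(False, False)`; once inside the refresh window `handle_token` calls `update_token` again, and only a fully elapsed expiry raises `TokenExpired`.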
+# -------------------------------------------------------------------------- + +# pylint: disable=client-accepts-api-version-keyword +# pylint: disable=missing-client-constructor-parameter-credential +# pylint: disable=client-method-missing-type-annotations +# pylint: disable=too-many-lines +# TODO: Check types of kwargs (issue exists for this) +import logging +import threading +import queue +import time +import uuid +from functools import partial +from typing import Any, Dict, Optional, Tuple, Union, overload, cast +import certifi +from typing_extensions import Literal + +from ._connection import Connection +from .message import _MessageDelivery +from .error import ( + AMQPException, + ErrorCondition, + MessageException, + MessageSendFailed, + RetryPolicy, + AMQPError, +) +from .outcomes import Received, Rejected, Released, Accepted, Modified + +from .constants import ( + MAX_CHANNELS, + MessageDeliveryState, + SenderSettleMode, + ReceiverSettleMode, + LinkDeliverySettleReason, + TransportType, + SEND_DISPOSITION_ACCEPT, + SEND_DISPOSITION_REJECT, + AUTH_TYPE_CBS, + MAX_FRAME_SIZE_BYTES, + INCOMING_WINDOW, + OUTGOING_WINDOW, + DEFAULT_AUTH_TIMEOUT, + MESSAGE_DELIVERY_DONE_STATES, +) + +from .management_operation import ManagementOperation +from .cbs import CBSAuthenticator + +Outcomes = Union[Received, Rejected, Released, Accepted, Modified] + + +_logger = logging.getLogger(__name__) + + +class AMQPClient( + object +): # pylint: disable=too-many-instance-attributes + """An AMQP client. + :param hostname: The AMQP endpoint to connect to. + :type hostname: str + :keyword auth: Authentication for the connection. This should be one of the following: + - pyamqp.authentication.SASLAnonymous + - pyamqp.authentication.SASLPlain + - pyamqp.authentication.SASTokenAuth + - pyamqp.authentication.JWTTokenAuth + If no authentication is supplied, SASLAnonymous will be used by default.
+ :paramtype auth: ~pyamqp.authentication + :keyword client_name: The name for the client, also known as the Container ID. + If no name is provided, a random GUID will be used. + :paramtype client_name: str or bytes + :keyword network_trace: Whether to turn on network trace logs. If `True`, trace logs + will be logged at INFO level. Default is `False`. + :paramtype network_trace: bool + :keyword retry_policy: A policy for parsing errors on link, connection and message + disposition to determine whether the error should be retryable. + :paramtype retry_policy: ~pyamqp.error.RetryPolicy + :keyword keep_alive_interval: If set, a thread will be started to keep the connection + alive during periods of user inactivity. The value will determine how long the + thread will sleep (in seconds) between pinging the connection. If 0 or None, no + thread will be started. + :paramtype keep_alive_interval: int + :keyword max_frame_size: Maximum AMQP frame size. Default is 63488 bytes. + :paramtype max_frame_size: int + :keyword channel_max: Maximum number of Session channels in the Connection. + :paramtype channel_max: int + :keyword idle_timeout: Timeout in seconds after which the Connection will close + if there is no further activity. + :paramtype idle_timeout: int + :keyword auth_timeout: Timeout in seconds for CBS authentication. Ignored if CBS + authentication is not in use. Default value is 60s. + :paramtype auth_timeout: int + :keyword properties: Connection properties. + :paramtype properties: dict[str, any] + :keyword remote_idle_timeout_empty_frame_send_ratio: Portion of the idle timeout time to wait before sending an + empty frame. The default portion is 50% of the idle timeout value (i.e. `0.5`). + :paramtype remote_idle_timeout_empty_frame_send_ratio: float + :keyword incoming_window: The size of the allowed window for incoming messages. + :paramtype incoming_window: int + :keyword outgoing_window: The size of the allowed window for outgoing messages.
+ :paramtype outgoing_window: int + :keyword handle_max: The maximum number of concurrent link handles. + :paramtype handle_max: int + :keyword on_attach: A callback function to be run on receipt of an ATTACH frame. + The function must take 4 arguments: source, target, properties and error. + :paramtype on_attach: func[ + ~pyamqp.endpoint.Source, ~pyamqp.endpoint.Target, dict, ~pyamqp.error.AMQPConnectionError] + :keyword send_settle_mode: The mode by which to settle message send + operations. If set to `Unsettled`, the client will wait for a confirmation + from the service that the message was successfully sent. If set to 'Settled', + the client will not wait for confirmation and assume success. + :paramtype send_settle_mode: ~pyamqp.constants.SenderSettleMode + :keyword receive_settle_mode: The mode by which to settle message receive + operations. If set to `PeekLock`, the receiver will lock a message once received until + the client accepts or rejects the message. If set to `ReceiveAndDelete`, the service + will assume successful receipt of the message and clear it from the queue. The + default is `PeekLock`. + :paramtype receive_settle_mode: ~pyamqp.constants.ReceiverSettleMode + :keyword desired_capabilities: The extension capabilities desired from the peer endpoint. + :paramtype desired_capabilities: list[bytes] + :keyword max_message_size: The maximum allowed message size negotiated for the Link. + :paramtype max_message_size: int + :keyword link_properties: Metadata to be sent in the Link ATTACH frame. + :paramtype link_properties: dict[str, any] + :keyword link_credit: The Link credit that determines how many + messages the Link will attempt to handle per connection iteration. + The default is 300. + :paramtype link_credit: int + :keyword transport_type: The type of transport protocol that will be used for communicating with + the service. Default is `TransportType.Amqp` in which case port 5671 is used. 
+ If the port 5671 is unavailable/blocked in the network environment, `TransportType.AmqpOverWebsocket` could + be used instead which uses port 443 for communication. + :paramtype transport_type: ~pyamqp.constants.TransportType + :keyword http_proxy: HTTP proxy settings. This must be a dictionary with the following + keys: `'proxy_hostname'` (str value) and `'proxy_port'` (int value). + Additionally the following keys may also be present: `'username', 'password'`. + :paramtype http_proxy: dict[str, str] + :keyword custom_endpoint_address: The custom endpoint address to use for establishing a connection to + the service, allowing network requests to be routed through any application gateways or + other paths needed for the host environment. Default is None. + If port is not specified in the `custom_endpoint_address`, by default port 443 will be used. + :paramtype custom_endpoint_address: str + :keyword connection_verify: Path to the custom CA_BUNDLE file of the SSL certificate which is used to + authenticate the identity of the connection endpoint. + Default is None in which case `certifi.where()` will be used. 
+ :paramtype connection_verify: str + """ + + def __init__(self, hostname, **kwargs): + # I think these are just strings not instances of target or source + self._hostname = hostname + self._auth = kwargs.pop("auth", None) + self._name = kwargs.pop("client_name", str(uuid.uuid4())) + self._shutdown = False + self._connection = None + self._session = None + self._link = None + self._socket_timeout = False + self._external_connection = False + self._cbs_authenticator = None + self._auth_timeout = kwargs.pop("auth_timeout", DEFAULT_AUTH_TIMEOUT) + self._mgmt_links = {} + self._mgmt_link_lock = threading.Lock() + self._retry_policy = kwargs.pop("retry_policy", RetryPolicy()) + self._keep_alive_interval = int(kwargs.get("keep_alive_interval", 0)) + self._keep_alive_thread = None + + # Connection settings + self._max_frame_size = kwargs.pop("max_frame_size", MAX_FRAME_SIZE_BYTES) + self._channel_max = kwargs.pop("channel_max", MAX_CHANNELS) + self._idle_timeout = kwargs.pop("idle_timeout", None) + self._properties = kwargs.pop("properties", None) + self._remote_idle_timeout_empty_frame_send_ratio = kwargs.pop( + "remote_idle_timeout_empty_frame_send_ratio", None + ) + self._network_trace = kwargs.pop("network_trace", False) + self._network_trace_params = {"amqpConnection": None, "amqpSession": None, "amqpLink": None} + + # Session settings + self._outgoing_window = kwargs.pop("outgoing_window", OUTGOING_WINDOW) + self._incoming_window = kwargs.pop("incoming_window", INCOMING_WINDOW) + self._handle_max = kwargs.pop("handle_max", None) + + # Link settings + self._send_settle_mode = kwargs.pop( + "send_settle_mode", SenderSettleMode.Unsettled + ) + self._receive_settle_mode = kwargs.pop( + "receive_settle_mode", ReceiverSettleMode.Second + ) + self._desired_capabilities = kwargs.pop("desired_capabilities", None) + self._on_attach = kwargs.pop("on_attach", None) + + # transport + if ( + kwargs.get("transport_type") is TransportType.Amqp + and kwargs.get("http_proxy") is not 
None + ): + raise ValueError( + "Http proxy settings can't be passed if transport_type is explicitly set to Amqp" + ) + self._transport_type = kwargs.pop("transport_type", TransportType.Amqp) + self._http_proxy = kwargs.pop("http_proxy", None) + + # Custom Endpoint + self._custom_endpoint_address = kwargs.get("custom_endpoint_address") + self._connection_verify = kwargs.get("connection_verify") + + def __enter__(self): + """Run Client in a context manager.""" + self.open() + return self + + def __exit__(self, *args): + """Close and destroy Client on exiting a context manager.""" + self.close() + + def _keep_alive(self): + start_time = time.time() + try: + while self._connection and not self._shutdown: + current_time = time.time() + elapsed_time = current_time - start_time + if elapsed_time >= self._keep_alive_interval: + _logger.debug("Keeping %r connection alive.", self.__class__.__name__) + self._connection.listen(wait=self._socket_timeout, batch=self._link.current_link_credit) + start_time = current_time + time.sleep(1) + except Exception as e: # pylint: disable=broad-except + _logger.debug("Connection keep-alive for %r failed: %r.", self.__class__.__name__, e) + + def _client_ready(self): # pylint: disable=no-self-use + """Determine whether the client is ready to start sending and/or + receiving messages. To be ready, the connection must be open and + authentication complete. 
+ + :rtype: bool + """ + return True + + def _client_run(self, **kwargs): + """Perform a single Connection iteration.""" + self._connection.listen(wait=self._socket_timeout, **kwargs) + + def _close_link(self): + if self._link and not self._link._is_closed: # pylint: disable=protected-access + self._link.detach(close=True) + self._link = None + + def _do_retryable_operation(self, operation, *args, **kwargs): + retry_settings = self._retry_policy.configure_retries() + retry_active = True + absolute_timeout = kwargs.pop("timeout", 0) or 0 + start_time = time.time() + while retry_active: + try: + if absolute_timeout < 0: + raise TimeoutError("Operation timed out.") + return operation(*args, timeout=absolute_timeout, **kwargs) + except AMQPException as exc: + if not self._retry_policy.is_retryable(exc): + raise + if absolute_timeout >= 0: + retry_active = self._retry_policy.increment(retry_settings, exc) + if not retry_active: + break + time.sleep(self._retry_policy.get_backoff_time(retry_settings, exc)) + if exc.condition == ErrorCondition.LinkDetachForced: + self._close_link() # if link level error, close and open a new link + if exc.condition in ( + ErrorCondition.ConnectionCloseForced, + ErrorCondition.SocketError, + ): + # if connection detach or socket error, close and open a new connection + self.close() + finally: + end_time = time.time() + if absolute_timeout > 0: + absolute_timeout -= end_time - start_time + raise retry_settings["history"][-1] + + def open(self, connection=None): + """Open the client. The client can create a new Connection + or an existing Connection can be passed in. This existing Connection + may have an existing CBS authentication Session, which will be + used for this client as well. Otherwise a new Session will be + created. + + :param connection: An existing Connection that may be shared between + multiple clients. 
+ :type connection: ~pyamqp.Connection + """ + # pylint: disable=protected-access + if self._session: + return # already open. + if connection: + self._connection = connection + self._external_connection = True + elif not self._connection: + self._connection = Connection( + "amqps://" + self._hostname, + sasl_credential=self._auth.sasl, + ssl_opts={"ca_certs": self._connection_verify or certifi.where()}, + container_id=self._name, + max_frame_size=self._max_frame_size, + channel_max=self._channel_max, + idle_timeout=self._idle_timeout, + properties=self._properties, + network_trace=self._network_trace, + transport_type=self._transport_type, + http_proxy=self._http_proxy, + custom_endpoint_address=self._custom_endpoint_address, + ) + self._connection.open() + if not self._session: + self._session = self._connection.create_session( + incoming_window=self._incoming_window, + outgoing_window=self._outgoing_window, + ) + self._session.begin() + if self._keep_alive_interval: + self._keep_alive_thread = threading.Thread(target=self._keep_alive) + self._keep_alive_thread.daemon = True + self._keep_alive_thread.start() + if self._auth.auth_type == AUTH_TYPE_CBS: + self._cbs_authenticator = CBSAuthenticator( + session=self._session, auth=self._auth, auth_timeout=self._auth_timeout + ) + self._cbs_authenticator.open() + self._network_trace_params["amqpConnection"] = self._connection._container_id + self._network_trace_params["amqpSession"] = self._session.name + self._shutdown = False + + def close(self): + """Close the client. This includes closing the Session + and CBS authentication layer as well as the Connection. + If the client was opened using an external Connection, + this will be left intact. + + No further messages can be sent or received and the client + cannot be re-opened. + + All pending, unsent messages will remain uncleared to allow + them to be inspected and queued to a new client. 
+ """ + self._shutdown = True + if not self._session: + return # already closed. + self._close_link() + if self._cbs_authenticator: + self._cbs_authenticator.close() + self._cbs_authenticator = None + self._session.end() + self._session = None + if not self._external_connection: + self._connection.close() + self._connection = None + if self._keep_alive_thread: + try: + self._keep_alive_thread.join() + except RuntimeError: # Probably thread failed to start in .open() + logging.debug("Keep alive thread failed to join.", exc_info=True) + self._keep_alive_thread = None + self._network_trace_params["amqpConnection"] = None + self._network_trace_params["amqpSession"] = None + + def auth_complete(self): + """Whether the authentication handshake is complete during + connection initialization. + + :rtype: bool + """ + if self._cbs_authenticator and not self._cbs_authenticator.handle_token(): + self._connection.listen(wait=self._socket_timeout) + return False + return True + + def client_ready(self): + """ + Whether the handler has completed all start up processes such as + establishing the connection, session, link and authentication, and + is not ready to process messages. + + :rtype: bool + """ + if not self.auth_complete(): + return False + if not self._client_ready(): + try: + self._connection.listen(wait=self._socket_timeout) + except ValueError: + return True + return False + return True + + def do_work(self, **kwargs): + """Run a single connection iteration. + This will return `True` if the connection is still open + and ready to be used for further work, or `False` if it needs + to be shut down. + + :rtype: bool + :raises: TimeoutError if CBS authentication timeout reached. + """ + if self._shutdown: + return False + if not self.client_ready(): + return True + return self._client_run(**kwargs) + + def mgmt_request(self, message, **kwargs): + """ + :param message: The message to send in the management request. 
+ :type message: ~pyamqp.message.Message + :keyword str operation: The type of operation to be performed. This value will + be service-specific, but common values include READ, CREATE and UPDATE. + This value will be added as an application property on the message. + :keyword str operation_type: The type on which to carry out the operation. This will + be specific to the entities of the service. This value will be added as + an application property on the message. + :keyword str node: The target node. Default node is `$management`. + :keyword float timeout: Provide an optional timeout in seconds within which a response + to the management request must be received. + :rtype: ~pyamqp.message.Message + """ + + # The method also takes "status_code_field" and "status_description_field" + # keyword arguments as alternate names for the status code and description + # in the response body. Those two keyword arguments are used in Azure services only. + operation = kwargs.pop("operation", None) + operation_type = kwargs.pop("operation_type", None) + node = kwargs.pop("node", "$management") + timeout = kwargs.pop("timeout", 0) + with self._mgmt_link_lock: + try: + mgmt_link = self._mgmt_links[node] + except KeyError: + mgmt_link = ManagementOperation(self._session, endpoint=node, **kwargs) + self._mgmt_links[node] = mgmt_link + mgmt_link.open() + + while not mgmt_link.ready(): + self._connection.listen(wait=False) + operation_type = operation_type or b"empty" + status, description, response = mgmt_link.execute( + message, operation=operation, operation_type=operation_type, timeout=timeout + ) + return status, description, response + + +class SendClient(AMQPClient): + """ + An AMQP client for sending messages. + :param target: The target AMQP service endpoint. This can either be the URI as + a string or a ~pyamqp.endpoint.Target object. + :type target: str, bytes or ~pyamqp.endpoint.Target + :keyword auth: Authentication for the connection. 
This should be one of the following: + - pyamqp.authentication.SASLAnonymous + - pyamqp.authentication.SASLPlain + - pyamqp.authentication.SASTokenAuth + - pyamqp.authentication.JWTTokenAuth + If no authentication is supplied, SASLAnonymous will be used by default. + :paramtype auth: ~pyamqp.authentication + :keyword client_name: The name for the client, also known as the Container ID. + If no name is provided, a random GUID will be used. + :paramtype client_name: str or bytes + :keyword network_trace: Whether to turn on network trace logs. If `True`, trace logs + will be logged at INFO level. Default is `False`. + :paramtype network_trace: bool + :keyword retry_policy: A policy for parsing errors on link, connection and message + disposition to determine whether the error should be retryable. + :paramtype retry_policy: ~pyamqp.error.RetryPolicy + :keyword keep_alive_interval: If set, a thread will be started to keep the connection + alive during periods of user inactivity. The value will determine how long the + thread will sleep (in seconds) between pinging the connection. If 0 or None, no + thread will be started. + :paramtype keep_alive_interval: int + :keyword max_frame_size: Maximum AMQP frame size. Default is 63488 bytes. + :paramtype max_frame_size: int + :keyword channel_max: Maximum number of Session channels in the Connection. + :paramtype channel_max: int + :keyword idle_timeout: Timeout in seconds after which the Connection will close + if there is no further activity. + :paramtype idle_timeout: int + :keyword auth_timeout: Timeout in seconds for CBS authentication. Ignored if CBS + authentication is not in use. Default value is 60s. + :paramtype auth_timeout: int + :keyword properties: Connection properties. + :paramtype properties: dict[str, any] + :keyword remote_idle_timeout_empty_frame_send_ratio: Portion of the idle timeout time to wait before sending an + empty frame. The default portion is 50% of the idle timeout value (i.e. `0.5`).
+ :paramtype remote_idle_timeout_empty_frame_send_ratio: float + :keyword incoming_window: The size of the allowed window for incoming messages. + :paramtype incoming_window: int + :keyword outgoing_window: The size of the allowed window for outgoing messages. + :paramtype outgoing_window: int + :keyword handle_max: The maximum number of concurrent link handles. + :paramtype handle_max: int + :keyword on_attach: A callback function to be run on receipt of an ATTACH frame. + The function must take 4 arguments: source, target, properties and error. + :paramtype on_attach: func[ + ~pyamqp.endpoint.Source, ~pyamqp.endpoint.Target, dict, ~pyamqp.error.AMQPConnectionError] + :keyword send_settle_mode: The mode by which to settle message send + operations. If set to `Unsettled`, the client will wait for a confirmation + from the service that the message was successfully sent. If set to 'Settled', + the client will not wait for confirmation and assume success. + :paramtype send_settle_mode: ~pyamqp.constants.SenderSettleMode + :keyword receive_settle_mode: The mode by which to settle message receive + operations. If set to `PeekLock`, the receiver will lock a message once received until + the client accepts or rejects the message. If set to `ReceiveAndDelete`, the service + will assume successful receipt of the message and clear it from the queue. The + default is `PeekLock`. + :paramtype receive_settle_mode: ~pyamqp.constants.ReceiverSettleMode + :keyword desired_capabilities: The extension capabilities desired from the peer endpoint. + :paramtype desired_capabilities: list[bytes] + :keyword max_message_size: The maximum allowed message size negotiated for the Link. + :paramtype max_message_size: int + :keyword link_properties: Metadata to be sent in the Link ATTACH frame. + :paramtype link_properties: dict[str, any] + :keyword link_credit: The Link credit that determines how many + messages the Link will attempt to handle per connection iteration. + The default is 300. 
+ :paramtype link_credit: int + :keyword transport_type: The type of transport protocol that will be used for communicating with + the service. Default is `TransportType.Amqp` in which case port 5671 is used. + If the port 5671 is unavailable/blocked in the network environment, `TransportType.AmqpOverWebsocket` could + be used instead which uses port 443 for communication. + :paramtype transport_type: ~pyamqp.constants.TransportType + :keyword http_proxy: HTTP proxy settings. This must be a dictionary with the following + keys: `'proxy_hostname'` (str value) and `'proxy_port'` (int value). + Additionally the following keys may also be present: `'username', 'password'`. + :paramtype http_proxy: dict[str, str] + :keyword custom_endpoint_address: The custom endpoint address to use for establishing a connection to + the service, allowing network requests to be routed through any application gateways or + other paths needed for the host environment. Default is None. + If port is not specified in the `custom_endpoint_address`, by default port 443 will be used. + :paramtype custom_endpoint_address: str + :keyword connection_verify: Path to the custom CA_BUNDLE file of the SSL certificate which is used to + authenticate the identity of the connection endpoint. + Default is None in which case `certifi.where()` will be used. + :paramtype connection_verify: str + """ + + def __init__(self, hostname, target, **kwargs): + self.target = target + # Sender and Link settings + self._max_message_size = kwargs.pop("max_message_size", MAX_FRAME_SIZE_BYTES) + self._link_properties = kwargs.pop("link_properties", None) + self._link_credit = kwargs.pop("link_credit", None) + super(SendClient, self).__init__(hostname, **kwargs) + + def _client_ready(self): + """Determine whether the client is ready to start sending messages. + To be ready, the connection must be open and authentication complete. + The Session, Link and MessageSender must be open and in non-errored + states.
+ + :rtype: bool + """ + # pylint: disable=protected-access + if not self._link: + self._link = self._session.create_sender_link( + target_address=self.target, + link_credit=self._link_credit, + send_settle_mode=self._send_settle_mode, + rcv_settle_mode=self._receive_settle_mode, + max_message_size=self._max_message_size, + properties=self._link_properties, + ) + self._link.attach() + return False + if self._link.get_state().value != 3: # ATTACHED + return False + return True + + def _client_run(self, **kwargs): + """MessageSender Link is now open - perform message send + on all pending messages. + Will return True if operation successful and client can remain open for + further work. + + :rtype: bool + """ + self._link.update_pending_deliveries() + self._connection.listen(wait=self._socket_timeout, **kwargs) + return True + + def _transfer_message(self, message_delivery, timeout=0): + message_delivery.state = MessageDeliveryState.WaitingForSendAck + on_send_complete = partial(self._on_send_complete, message_delivery) + delivery = self._link.send_transfer( + message_delivery.message, + on_send_complete=on_send_complete, + timeout=timeout, + send_async=True, + ) + return delivery + + @staticmethod + def _process_send_error(message_delivery, condition, description=None, info=None): + try: + amqp_condition = ErrorCondition(condition) + except ValueError: + error = MessageException(condition, description=description, info=info) + else: + error = MessageSendFailed( + amqp_condition, description=description, info=info + ) + message_delivery.state = MessageDeliveryState.Error + message_delivery.error = error + + def _on_send_complete(self, message_delivery, reason, state): + message_delivery.reason = reason + if reason == LinkDeliverySettleReason.DISPOSITION_RECEIVED: + if state and SEND_DISPOSITION_ACCEPT in state: + message_delivery.state = MessageDeliveryState.Ok + else: + try: + error_info = state[SEND_DISPOSITION_REJECT] + self._process_send_error( + 
message_delivery, + condition=error_info[0][0], + description=error_info[0][1], + info=error_info[0][2], + ) + except TypeError: + self._process_send_error( + message_delivery, condition=ErrorCondition.UnknownError + ) + elif reason == LinkDeliverySettleReason.SETTLED: + message_delivery.state = MessageDeliveryState.Ok + elif reason == LinkDeliverySettleReason.TIMEOUT: + message_delivery.state = MessageDeliveryState.Timeout + message_delivery.error = TimeoutError("Sending message timed out.") + else: + # NotDelivered and other unknown errors + self._process_send_error( + message_delivery, condition=ErrorCondition.UnknownError + ) + + def _send_message_impl(self, message, **kwargs): + timeout = kwargs.pop("timeout", 0) + expire_time = (time.time() + timeout) if timeout else None + self.open() + message_delivery = _MessageDelivery( + message, MessageDeliveryState.WaitingToBeSent, expire_time + ) + while not self.client_ready(): + time.sleep(0.05) + + self._transfer_message(message_delivery, timeout) + running = True + while running and message_delivery.state not in MESSAGE_DELIVERY_DONE_STATES: + running = self.do_work() + if message_delivery.state not in MESSAGE_DELIVERY_DONE_STATES: + raise MessageException( + condition=ErrorCondition.ClientError, + description="Send failed - connection not running." + ) + + if message_delivery.state in ( + MessageDeliveryState.Error, + MessageDeliveryState.Cancelled, + MessageDeliveryState.Timeout, + ): + try: + raise message_delivery.error # pylint: disable=raising-bad-type + except TypeError: + # This is a default handler + raise MessageException( + condition=ErrorCondition.UnknownError, description="Send failed." + ) + + def send_message(self, message, **kwargs): + """ + :param ~pyamqp.message.Message message: + :keyword float timeout: timeout in seconds. If set to + 0, the client will continue to wait until the message is sent or error happens. The + default is 0. 
+ """ + self._do_retryable_operation(self._send_message_impl, message=message, **kwargs) + + +class ReceiveClient(AMQPClient): # pylint:disable=too-many-instance-attributes + """ + An AMQP client for receiving messages. + :param source: The source AMQP service endpoint. This can either be the URI as + a string or a ~pyamqp.endpoint.Source object. + :type source: str, bytes or ~pyamqp.endpoint.Source + :keyword auth: Authentication for the connection. This should be one of the following: + - pyamqp.authentication.SASLAnonymous + - pyamqp.authentication.SASLPlain + - pyamqp.authentication.SASTokenAuth + - pyamqp.authentication.JWTTokenAuth + If no authentication is supplied, SASLAnonymous will be used by default. + :paramtype auth: ~pyamqp.authentication + :keyword client_name: The name for the client, also known as the Container ID. + If no name is provided, a random GUID will be used. + :paramtype client_name: str or bytes + :keyword network_trace: Whether to turn on network trace logs. If `True`, trace logs + will be logged at INFO level. Default is `False`. + :paramtype network_trace: bool + :keyword retry_policy: A policy for parsing errors on link, connection and message + disposition to determine whether the error should be retryable. + :paramtype retry_policy: ~pyamqp.error.RetryPolicy + :keyword keep_alive_interval: If set, a thread will be started to keep the connection + alive during periods of user inactivity. The value will determine how long the + thread will sleep (in seconds) between pinging the connection. If 0 or None, no + thread will be started. + :paramtype keep_alive_interval: int + :keyword max_frame_size: Maximum AMQP frame size. Default is 63488 bytes. + :paramtype max_frame_size: int + :keyword channel_max: Maximum number of Session channels in the Connection. + :paramtype channel_max: int + :keyword idle_timeout: Timeout in seconds after which the Connection will close + if there is no further activity.
+ :paramtype idle_timeout: int + :keyword auth_timeout: Timeout in seconds for CBS authentication. If CBS authentication is not + in use, this value will be ignored. Default value is 60s. + :paramtype auth_timeout: int + :keyword properties: Connection properties. + :paramtype properties: dict[str, any] + :keyword remote_idle_timeout_empty_frame_send_ratio: Portion of the idle timeout time to wait before sending an + empty frame. The default portion is 50% of the idle timeout value (i.e. `0.5`). + :paramtype remote_idle_timeout_empty_frame_send_ratio: float + :keyword incoming_window: The size of the allowed window for incoming messages. + :paramtype incoming_window: int + :keyword outgoing_window: The size of the allowed window for outgoing messages. + :paramtype outgoing_window: int + :keyword handle_max: The maximum number of concurrent link handles. + :paramtype handle_max: int + :keyword on_attach: A callback function to be run on receipt of an ATTACH frame. + The function must take 4 arguments: source, target, properties and error. + :paramtype on_attach: func[ + ~pyamqp.endpoint.Source, ~pyamqp.endpoint.Target, dict, ~pyamqp.error.AMQPConnectionError] + :keyword send_settle_mode: The mode by which to settle message send + operations. If set to `Unsettled`, the client will wait for a confirmation + from the service that the message was successfully sent. If set to `Settled`, + the client will not wait for confirmation and assume success. + :paramtype send_settle_mode: ~pyamqp.constants.SenderSettleMode + :keyword receive_settle_mode: The mode by which to settle message receive + operations. If set to `PeekLock`, the receiver will lock a message once received until + the client accepts or rejects the message. If set to `ReceiveAndDelete`, the service + will assume successful receipt of the message and clear it from the queue. The + default is `PeekLock`.
+ :paramtype receive_settle_mode: ~pyamqp.constants.ReceiverSettleMode + :keyword desired_capabilities: The extension capabilities desired from the peer endpoint. + :paramtype desired_capabilities: list[bytes] + :keyword max_message_size: The maximum allowed message size negotiated for the Link. + :paramtype max_message_size: int + :keyword link_properties: Metadata to be sent in the Link ATTACH frame. + :paramtype link_properties: dict[str, any] + :keyword link_credit: The Link credit that determines how many + messages the Link will attempt to handle per connection iteration. + The default is 300. + :paramtype link_credit: int + :keyword transport_type: The type of transport protocol that will be used for communicating with + the service. Default is `TransportType.Amqp` in which case port 5671 is used. + If the port 5671 is unavailable/blocked in the network environment, `TransportType.AmqpOverWebsocket` could + be used instead which uses port 443 for communication. + :paramtype transport_type: ~pyamqp.constants.TransportType + :keyword http_proxy: HTTP proxy settings. This must be a dictionary with the following + keys: `'proxy_hostname'` (str value) and `'proxy_port'` (int value). + Additionally the following keys may also be present: `'username', 'password'`. + :paramtype http_proxy: dict[str, str] + :keyword custom_endpoint_address: The custom endpoint address to use for establishing a connection to + the service, allowing network requests to be routed through any application gateways or + other paths needed for the host environment. Default is None. + If port is not specified in the `custom_endpoint_address`, by default port 443 will be used. + :paramtype custom_endpoint_address: str + :keyword connection_verify: Path to the custom CA_BUNDLE file of the SSL certificate which is used to + authenticate the identity of the connection endpoint. + Default is None in which case `certifi.where()` will be used. 
+ :paramtype connection_verify: str + """ + + def __init__(self, hostname, source, **kwargs): + self.source = source + self._streaming_receive = kwargs.pop("streaming_receive", False) + self._received_messages = queue.Queue() + self._message_received_callback = kwargs.pop("message_received_callback", None) + + # Receiver and Link settings + self._max_message_size = kwargs.pop("max_message_size", MAX_FRAME_SIZE_BYTES) + self._link_properties = kwargs.pop("link_properties", None) + self._link_credit = kwargs.pop("link_credit", 300) + + # Iterator + self._timeout = kwargs.pop("timeout", 0) + self._timeout_reached = False + self._last_activity_timestamp = time.time() + + super(ReceiveClient, self).__init__(hostname, **kwargs) + + def _client_ready(self): + """Determine whether the client is ready to start receiving messages. + To be ready, the connection must be open and authentication complete. + The Session, Link and MessageReceiver must be open and in non-errored + states. + + :rtype: bool + """ + # pylint: disable=protected-access + if not self._link: + self._link = self._session.create_receiver_link( + source_address=self.source, + link_credit=self._link_credit, + send_settle_mode=self._send_settle_mode, + rcv_settle_mode=self._receive_settle_mode, + max_message_size=self._max_message_size, + on_transfer=self._message_received, + properties=self._link_properties, + desired_capabilities=self._desired_capabilities, + on_attach=self._on_attach, + ) + self._link.attach() + return False + if self._link.get_state().value != 3: # ATTACHED + return False + return True + + def _client_run(self, **kwargs): + """MessageReceiver Link is now open - start receiving messages. + Will return True if operation successful and client can remain open for + further work.
+ + :rtype: bool + """ + try: + if self._link.current_link_credit == 0: + self._link.flow() + self._connection.listen(wait=self._socket_timeout, **kwargs) + except ValueError: + _logger.info("Timeout reached, closing receiver.", extra=self._network_trace_params) + self._shutdown = True + return False + return True + + def _message_received(self, frame, message): + """Callback run on receipt of every message. If there is + a user-defined callback, this will be called. + Additionally, if the client is retrieving messages for a batch + or iterator, the message will be added to an internal queue. + + :param frame: The incoming transfer frame. + :param message: Received message. + :type message: ~pyamqp.message.Message + """ + self._last_activity_timestamp = time.time() + if self._message_received_callback: + self._message_received_callback(message) + if not self._streaming_receive: + self._received_messages.put((frame, message)) + + def _receive_message_batch_impl( + self, max_batch_size=None, on_message_received=None, timeout=0 + ): + self._message_received_callback = on_message_received + max_batch_size = max_batch_size or self._link_credit + timeout = time.time() + timeout if timeout else 0 + receiving = True + batch = [] + self.open() + while len(batch) < max_batch_size: + try: + # TODO: This drops the transfer frame data + _, message = self._received_messages.get_nowait() + batch.append(message) + self._received_messages.task_done() + except queue.Empty: + break + else: + return batch + + to_receive_size = max_batch_size - len(batch) + before_queue_size = self._received_messages.qsize() + + while receiving and to_receive_size > 0: + if timeout and time.time() > timeout: + break + + receiving = self.do_work(batch=to_receive_size) + cur_queue_size = self._received_messages.qsize() + # after do_work, check how many new messages have been received since previous iteration + received = cur_queue_size - before_queue_size + if to_receive_size < max_batch_size and received == 0: + # there are already messages in the
batch, and no message is received in the current cycle + # return what we have + break + + to_receive_size -= received + before_queue_size = cur_queue_size + + while len(batch) < max_batch_size: + try: + _, message = self._received_messages.get_nowait() + batch.append(message) + self._received_messages.task_done() + except queue.Empty: + break + return batch + + def close(self): + self._received_messages = queue.Queue() + super(ReceiveClient, self).close() + + def receive_message_batch(self, **kwargs): + """Receive a batch of messages. Messages returned in the batch have already been + accepted - if you wish to add logic to accept or reject messages based on custom + criteria, pass in a callback. This method will return as soon as some messages are + available rather than waiting to achieve a specific batch size, and therefore the + number of messages returned per call will vary up to the maximum allowed. + + :param max_batch_size: The maximum number of messages that can be returned in + one call. This value cannot be larger than the prefetch value, and if not specified, + the prefetch value will be used. + :type max_batch_size: int + :param on_message_received: A callback to process messages as they arrive from the + service. It takes a single argument, a ~pyamqp.message.Message object. + :type on_message_received: callable[~pyamqp.message.Message] + :param timeout: The timeout in seconds for which to wait to receive any messages. + If no messages are received in this time, an empty list will be returned. If set to + 0, the client will continue to wait until at least one message is received. The + default is 0. + :type timeout: float + """ + return self._do_retryable_operation(self._receive_message_batch_impl, **kwargs) + + def receive_messages_iter(self, timeout=None, on_message_received=None): + """Receive messages by generator.
Messages returned in the generator have already been + accepted - if you wish to add logic to accept or reject messages based on custom + criteria, pass in a callback. + + :param timeout: The idle timeout in seconds after which to stop receiving messages. + If no messages are received within this period, iteration ends. If None, the timeout + passed at client construction will be used. The default is None. + :type timeout: float or None + :param on_message_received: A callback to process messages as they arrive from the + service. It takes a single argument, a ~pyamqp.message.Message object. + :type on_message_received: callable[~pyamqp.message.Message] + """ + self._message_received_callback = on_message_received + return self._message_generator(timeout=timeout) + + def _message_generator(self, timeout=None): + """Iterate over processed messages in the receive queue. + + :rtype: generator[~pyamqp.message.Message] + """ + self.open() + self._timeout_reached = False + receiving = True + message = None + self._last_activity_timestamp = time.time() + self._timeout = timeout if timeout else self._timeout + try: + while receiving and not self._timeout_reached: + if self._timeout > 0: + if time.time() - self._last_activity_timestamp >= self._timeout: + self._timeout_reached = True + + if not self._timeout_reached: + receiving = self.do_work() + + while not self._received_messages.empty(): + message = self._received_messages.get() + self._last_activity_timestamp = time.time() + self._received_messages.task_done() + yield message + + finally: + if self._shutdown: + self.close() + + @overload + def settle_messages( + self, + delivery_id: Union[int, Tuple[int, int]], + outcome: Literal["accepted"], + *, + batchable: Optional[bool] = None + ): + ... + + @overload + def settle_messages( + self, + delivery_id: Union[int, Tuple[int, int]], + outcome: Literal["released"], + *, + batchable: Optional[bool] = None + ): + ... + + @overload + def settle_messages( + self, + delivery_id: Union[int, Tuple[int, int]], + outcome: Literal["rejected"], + *, + error: Optional[AMQPError] = None, + batchable: Optional[bool] = None + ): + ...
+ + @overload + def settle_messages( + self, + delivery_id: Union[int, Tuple[int, int]], + outcome: Literal["modified"], + *, + delivery_failed: Optional[bool] = None, + undeliverable_here: Optional[bool] = None, + message_annotations: Optional[Dict[Union[str, bytes], Any]] = None, + batchable: Optional[bool] = None + ): + ... + + @overload + def settle_messages( + self, + delivery_id: Union[int, Tuple[int, int]], + outcome: Literal["received"], + *, + section_number: int, + section_offset: int, + batchable: Optional[bool] = None + ): + ... + + def settle_messages( + self, delivery_id: Union[int, Tuple[int, int]], outcome: str, **kwargs + ): + batchable = kwargs.pop("batchable", None) + if outcome.lower() == "accepted": + state: Outcomes = Accepted() + elif outcome.lower() == "released": + state = Released() + elif outcome.lower() == "rejected": + state = Rejected(**kwargs) + elif outcome.lower() == "modified": + state = Modified(**kwargs) + elif outcome.lower() == "received": + state = Received(**kwargs) + else: + raise ValueError("Unrecognized message outcome: {}".format(outcome)) + try: + first, last = cast(Tuple, delivery_id) + except TypeError: + first = delivery_id + last = None + self._link.send_disposition( + first_delivery_id=first, + last_delivery_id=last, + settled=True, + delivery_state=state, + batchable=batchable, + wait=True, + ) diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/constants.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/constants.py new file mode 100644 index 0000000000000..2acd4bc568827 --- /dev/null +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/constants.py @@ -0,0 +1,341 @@ +#------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information.
+#-------------------------------------------------------------------------- +from typing import cast +from collections import namedtuple +from enum import Enum +import struct + +_AS_BYTES = struct.Struct('>B') + +#: The IANA assigned port number for AMQP. The standard AMQP port number that has been assigned by IANA +#: for TCP, UDP, and SCTP. There are currently no UDP or SCTP mappings defined for AMQP. +#: The port number is reserved for future transport mappings to these protocols. +PORT = 5672 + +# default port for AMQP over Websocket +WEBSOCKET_PORT = 443 + +# subprotocol for AMQP over Websocket +AMQP_WS_SUBPROTOCOL = 'AMQPWSB10' + +#: The IANA assigned port number for secure AMQP (amqps). The standard AMQP port number that has been assigned +#: by IANA for secure TCP using TLS. Implementations listening on this port should NOT expect a protocol +#: handshake before TLS is negotiated. +SECURE_PORT = 5671 + + +MAJOR = 1 #: Major protocol version. +MINOR = 0 #: Minor protocol version. +REV = 0 #: Protocol revision. +HEADER_FRAME = b"AMQP\x00" + _AS_BYTES.pack(MAJOR) + _AS_BYTES.pack(MINOR) + _AS_BYTES.pack(REV) + + +TLS_MAJOR = 1 #: Major protocol version. +TLS_MINOR = 0 #: Minor protocol version. +TLS_REV = 0 #: Protocol revision. +TLS_HEADER_FRAME = b"AMQP\x02" + _AS_BYTES.pack(TLS_MAJOR) + _AS_BYTES.pack(TLS_MINOR) + _AS_BYTES.pack(TLS_REV) + +SASL_MAJOR = 1 #: Major protocol version. +SASL_MINOR = 0 #: Minor protocol version. +SASL_REV = 0 #: Protocol revision. +SASL_HEADER_FRAME = b"AMQP\x03" + _AS_BYTES.pack(SASL_MAJOR) + _AS_BYTES.pack(SASL_MINOR) + _AS_BYTES.pack(SASL_REV) + +EMPTY_FRAME = b'\x00\x00\x00\x08\x02\x00\x00\x00' + +#: The lower bound for the agreed maximum frame size (in bytes). During the initial Connection negotiation, the +#: two peers must agree upon a maximum frame size.
This constant defines the minimum value to which the maximum +#: frame size can be set. By defining this value, the peers can guarantee that they can send frames of up to this +#: size until they have agreed a definitive maximum frame size for that Connection. +MIN_MAX_FRAME_SIZE = 512 +MAX_FRAME_SIZE_BYTES = 1024 * 1024 +MAX_CHANNELS = 65535 +INCOMING_WINDOW = 64 * 1024 +OUTGOING_WINDOW = 64 * 1024 + +DEFAULT_LINK_CREDIT = 10000 + +FIELD = namedtuple('FIELD', 'name, type, mandatory, default, multiple') + +STRING_FILTER = b"apache.org:selector-filter:string" + +DEFAULT_AUTH_TIMEOUT = 60 +AUTH_DEFAULT_EXPIRATION_SECONDS = 3600 +TOKEN_TYPE_JWT = "jwt" +TOKEN_TYPE_SASTOKEN = "servicebus.windows.net:sastoken" +CBS_PUT_TOKEN = "put-token" +CBS_NAME = "name" +CBS_OPERATION = "operation" +CBS_TYPE = "type" +CBS_EXPIRATION = "expiration" + +SEND_DISPOSITION_ACCEPT = "accepted" +SEND_DISPOSITION_REJECT = "rejected" + +AUTH_TYPE_SASL_PLAIN = "AUTH_SASL_PLAIN" +AUTH_TYPE_CBS = "AUTH_CBS" + +DEFAULT_WEBSOCKET_HEARTBEAT_SECONDS = 10 +READ_TIMEOUT_INTERVAL = 0.2 +TIMEOUT_INTERVAL = 1 +WS_TIMEOUT_INTERVAL = 1 + + +class ConnectionState(Enum): + #: In this state a Connection exists, but nothing has been sent or received. This is the state an + #: implementation would be in immediately after performing a socket connect or socket accept. + START = 0 + #: In this state the Connection header has been received from our peer, but we have not yet sent anything. + HDR_RCVD = 1 + #: In this state the Connection header has been sent to our peer, but we have not yet received anything. + HDR_SENT = 2 + #: In this state we have sent and received the Connection header, but we have not yet sent or + #: received an open frame. + HDR_EXCH = 3 + #: In this state we have sent both the Connection header and the open frame, but + #: we have not yet received anything. 
+ OPEN_PIPE = 4 + #: In this state we have sent the Connection header, the open frame, any pipelined Connection traffic, + #: and the close frame, but we have not yet received anything. + OC_PIPE = 5 + #: In this state we have sent and received the Connection header, and received an open frame from + #: our peer, but have not yet sent an open frame. + OPEN_RCVD = 6 + #: In this state we have sent and received the Connection header, and sent an open frame to our peer, + #: but have not yet received an open frame. + OPEN_SENT = 7 + #: In this state we have sent and received the Connection header, sent an open frame, any pipelined + #: Connection traffic, and the close frame, but we have not yet received an open frame. + CLOSE_PIPE = 8 + #: In this state the Connection header and the open frame have both been sent and received. + OPENED = 9 + #: In this state we have received a close frame indicating that our partner has initiated a close. + #: This means we will never have to read anything more from this Connection, however we can + #: continue to write frames onto the Connection. If desired, an implementation could do a TCP half-close + #: at this point to shutdown the read side of the Connection. + CLOSE_RCVD = 10 + #: In this state we have sent a close frame to our partner. It is illegal to write anything more onto + #: the Connection, however there may still be incoming frames. If desired, an implementation could do + #: a TCP half-close at this point to shutdown the write side of the Connection. + CLOSE_SENT = 11 + #: The DISCARDING state is a variant of the CLOSE_SENT state where the close is triggered by an error. + #: In this case any incoming frames on the connection MUST be silently discarded until the peer's close + #: frame is received. + DISCARDING = 12 + #: In this state it is illegal for either endpoint to write anything more onto the Connection. The + #: Connection may be safely closed and discarded.
+ END = 13 + + +class SessionState(Enum): + #: In the UNMAPPED state, the Session endpoint is not mapped to any incoming or outgoing channels on the + #: Connection endpoint. In this state an endpoint cannot send or receive frames. + UNMAPPED = 0 + #: In the BEGIN_SENT state, the Session endpoint is assigned an outgoing channel number, but there is no entry + #: in the incoming channel map. In this state the endpoint may send frames but cannot receive them. + BEGIN_SENT = 1 + #: In the BEGIN_RCVD state, the Session endpoint has an entry in the incoming channel map, but has not yet + #: been assigned an outgoing channel number. The endpoint may receive frames, but cannot send them. + BEGIN_RCVD = 2 + #: In the MAPPED state, the Session endpoint has both an outgoing channel number and an entry in the incoming + #: channel map. The endpoint may both send and receive frames. + MAPPED = 3 + #: In the END_SENT state, the Session endpoint has an entry in the incoming channel map, but is no longer + #: assigned an outgoing channel number. The endpoint may receive frames, but cannot send them. + END_SENT = 4 + #: In the END_RCVD state, the Session endpoint is assigned an outgoing channel number, but there is no entry in + #: the incoming channel map. The endpoint may send frames, but cannot receive them. + END_RCVD = 5 + #: The DISCARDING state is a variant of the END_SENT state where the end is triggered by an error. In this + #: case any incoming frames on the session MUST be silently discarded until the peer's end frame is received. 
+ DISCARDING = 6 + + +class SessionTransferState(Enum): + + OKAY = 0 + ERROR = 1 + BUSY = 2 + + +class LinkDeliverySettleReason(Enum): + + DISPOSITION_RECEIVED = 0 + SETTLED = 1 + NOT_DELIVERED = 2 + TIMEOUT = 3 + CANCELLED = 4 + + +class LinkState(Enum): + + DETACHED = 0 + ATTACH_SENT = 1 + ATTACH_RCVD = 2 + ATTACHED = 3 + DETACH_SENT = 4 + DETACH_RCVD = 5 + ERROR = 6 + + +class ManagementLinkState(Enum): + + IDLE = 0 + OPENING = 1 + CLOSING = 2 + OPEN = 3 + ERROR = 4 + + +class ManagementOpenResult(Enum): + + OPENING = 0 + OK = 1 + ERROR = 2 + CANCELLED = 3 + + +class ManagementExecuteOperationResult(Enum): + + OK = 0 + ERROR = 1 + FAILED_BAD_STATUS = 2 + LINK_CLOSED = 3 + + +class CbsState(Enum): + CLOSED = 0 + OPENING = 1 + OPEN = 2 + ERROR = 3 + + +class CbsAuthState(Enum): + OK = 0 + IDLE = 1 + IN_PROGRESS = 2 + TIMEOUT = 3 + REFRESH_REQUIRED = 4 + EXPIRED = 5 + ERROR = 6 # Put token rejected or complete but fail authentication + FAILURE = 7 # Fail to open cbs links + + +class Role(object): + """Link endpoint role. + + Valid Values: + - False: Sender + - True: Receiver + """ + Sender = False + Receiver = True + + +class SenderSettleMode(object): + """Settlement policy for a Sender. + + Valid Values: + - 0: The Sender will send all deliveries initially unsettled to the Receiver. + - 1: The Sender will send all deliveries settled to the Receiver. + - 2: The Sender may send a mixture of settled and unsettled deliveries to the Receiver. + """ + Unsettled = 0 + Settled = 1 + Mixed = 2 + + +class ReceiverSettleMode(object): + """Settlement policy for a Receiver. + + Valid Values: + - 0: The Receiver will spontaneously settle all incoming transfers. + - 1: The Receiver will only settle after sending the disposition to the Sender and + receiving a disposition indicating settlement of the delivery from the sender. + """ + First = 0 + Second = 1 + + +class SASLCode(object): + """Codes to indicate the outcome of the sasl dialog.
+ """ + #: Connection authentication succeeded. + Ok = 0 + #: Connection authentication failed due to an unspecified problem with the supplied credentials. + Auth = 1 + #: Connection authentication failed due to a system error. + Sys = 2 + #: Connection authentication failed due to a system error that is unlikely to be corrected without intervention. + SysPerm = 3 + #: Connection authentication failed due to a transient system error. + SysTemp = 4 + + +class MessageDeliveryState(object): + + WaitingToBeSent = 0 + WaitingForSendAck = 1 + Ok = 2 + Error = 3 + Timeout = 4 + Cancelled = 5 + + +MESSAGE_DELIVERY_DONE_STATES = ( + MessageDeliveryState.Ok, + MessageDeliveryState.Error, + MessageDeliveryState.Timeout, + MessageDeliveryState.Cancelled +) + +class TransportType(Enum): + """Transport type + The underlying transport protocol type: + Amqp: AMQP over the default TCP transport protocol, it uses port 5671. + AmqpOverWebsocket: Amqp over the Web Sockets transport protocol, it uses + port 443. + """ + Amqp = 1 + AmqpOverWebsocket = 2 + + def __eq__(self, __o: object) -> bool: + try: + __o = cast(Enum, __o) + return self.value == __o.value + except AttributeError: + return super().__eq__(__o) diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/endpoints.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/endpoints.py new file mode 100644 index 0000000000000..bc8843bac37fd --- /dev/null +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/endpoints.py @@ -0,0 +1,280 @@ +#------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +#-------------------------------------------------------------------------- + +# The messaging layer defines two concrete types (source and target) to be used as the source and target of a +# link.
These types are supplied in the source and target fields of the attach frame when establishing or +# resuming a link. The source is comprised of an address (which the container of the outgoing Link Endpoint will +# resolve to a Node within that container) coupled with properties which determine: +# +# - which messages from the sending Node will be sent on the Link +# - how sending the message affects the state of that message at the sending Node +# - the behavior of Messages which have been transferred on the Link, but have not yet reached a +# terminal state at the receiver, when the source is destroyed. + +# TODO: fix mypy errors for _code/_definition/__defaults__ (issue #26500) +from collections import namedtuple + +from .types import AMQPTypes, FieldDefinition, ObjDefinition +from .constants import FIELD +from .performatives import _CAN_ADD_DOCSTRING + + +class TerminusDurability(object): + """Durability policy for a terminus. + + Determines which state of the terminus is held durably. + """ + #: No Terminus state is retained durably + NoDurability = 0 + #: Only the existence and configuration of the Terminus is retained durably. + Configuration = 1 + #: In addition to the existence and configuration of the Terminus, the unsettled state for durable + #: messages is retained durably. + UnsettledState = 2 + + +class ExpiryPolicy(object): + """Expiry policy for a terminus. + + Determines when the expiry timer of a terminus starts counting down from the timeout + value. If the link is subsequently re-attached before the terminus is expired, then the + count down is aborted. If the conditions for the terminus-expiry-policy are subsequently + re-met, the expiry timer restarts from its originally configured timeout value. + """ + #: The expiry timer starts when Terminus is detached. + LinkDetach = b"link-detach" + #: The expiry timer starts when the most recently associated session is ended.
+ SessionEnd = b"session-end" + #: The expiry timer starts when the most recently associated connection is closed. + ConnectionClose = b"connection-close" + #: The Terminus never expires. + Never = b"never" + + +class DistributionMode(object): + """Link distribution policy. + + Policies for distributing messages when multiple links are connected to the same node. + """ + #: Once successfully transferred over the link, the message will no longer be available + #: to other links from the same node. + Move = b'move' + #: Once successfully transferred over the link, the message is still available for other + #: links from the same node. + Copy = b'copy' + + +class LifeTimePolicy(object): + #: Lifetime of dynamic node scoped to lifetime of link which caused creation. + #: A node dynamically created with this lifetime policy will be deleted at the point that the link + #: which caused its creation ceases to exist. + DeleteOnClose = 0x0000002b + #: Lifetime of dynamic node scoped to existence of links to the node. + #: A node dynamically created with this lifetime policy will be deleted at the point that there remain + #: no links for which the node is either the source or target. + DeleteOnNoLinks = 0x0000002c + #: Lifetime of dynamic node scoped to existence of messages on the node. + #: A node dynamically created with this lifetime policy will be deleted at the point that the link which + #: caused its creation no longer exists and there remain no messages at the node. + DeleteOnNoMessages = 0x0000002d + #: Lifetime of node scoped to existence of messages on or links to the node. + #: A node dynamically created with this lifetime policy will be deleted at the point that there are no + #: links which have this node as their source or target, and there remain no messages at the node. + DeleteOnNoLinksOrMessages = 0x0000002e + + +class SupportedOutcomes(object): + #: Indicates successful processing at the receiver.
+ accepted = b"amqp:accepted:list" + #: Indicates an invalid and unprocessable message. + rejected = b"amqp:rejected:list" + #: Indicates that the message was not (and will not be) processed. + released = b"amqp:released:list" + #: Indicates that the message was modified, but not processed. + modified = b"amqp:modified:list" + + +class ApacheFilters(object): + #: Exact match on subject - analogous to legacy AMQP direct exchange bindings. + legacy_amqp_direct_binding = b"apache.org:legacy-amqp-direct-binding:string" + #: Pattern match on subject - analogous to legacy AMQP topic exchange bindings. + legacy_amqp_topic_binding = b"apache.org:legacy-amqp-topic-binding:string" + #: Matching on message headers - analogous to legacy AMQP headers exchange bindings. + legacy_amqp_headers_binding = b"apache.org:legacy-amqp-headers-binding:map" + #: Filter out messages sent from the same connection as the link is currently associated with. + no_local_filter = b"apache.org:no-local-filter:list" + #: SQL-based filtering syntax. 
+ selector_filter = b"apache.org:selector-filter:string" + + +Source = namedtuple( + 'Source', + [ + 'address', + 'durable', + 'expiry_policy', + 'timeout', + 'dynamic', + 'dynamic_node_properties', + 'distribution_mode', + 'filters', + 'default_outcome', + 'outcomes', + 'capabilities' + ], + defaults=(None,) * 11 # type: ignore + ) +Source._code = 0x00000028 # type: ignore # pylint: disable=protected-access +Source._definition = ( # type: ignore # pylint: disable=protected-access + FIELD("address", AMQPTypes.string, False, None, False), + FIELD("durable", AMQPTypes.uint, False, "none", False), + FIELD("expiry_policy", AMQPTypes.symbol, False, ExpiryPolicy.SessionEnd, False), + FIELD("timeout", AMQPTypes.uint, False, 0, False), + FIELD("dynamic", AMQPTypes.boolean, False, False, False), + FIELD("dynamic_node_properties", FieldDefinition.node_properties, False, None, False), + FIELD("distribution_mode", AMQPTypes.symbol, False, None, False), + FIELD("filters", FieldDefinition.filter_set, False, None, False), + FIELD("default_outcome", ObjDefinition.delivery_state, False, None, False), + FIELD("outcomes", AMQPTypes.symbol, False, None, True), + FIELD("capabilities", AMQPTypes.symbol, False, None, True)) +if _CAN_ADD_DOCSTRING: + Source.__doc__ = """ + For containers which do not implement address resolution (and do not admit spontaneous link + attachment from their partners) but are instead only used as producers of messages, it is unnecessary to provide + spurious detail on the source. For this purpose it is possible to use a "minimal" source in which all the + fields are left unset. + + :param str address: The address of the source. + The address of the source MUST NOT be set when sent on an attach frame sent by the receiving Link Endpoint + where the dynamic flag is set to true (that is where the receiver is requesting the sender to create an + addressable node).
The address of the source MUST be set when sent on an attach frame sent by the sending + Link Endpoint where the dynamic flag is set to true (that is where the sender has created an addressable + node at the request of the receiver and is now communicating the address of that created node). + The generated name of the address SHOULD include the link name and the container-id of the remote container + to allow for ease of identification. + :param ~uamqp.endpoints.TerminusDurability durable: Indicates the durability of the terminus. + Indicates what state of the terminus will be retained durably: the state of durable messages, only + existence and configuration of the terminus, or no state at all. + :param ~uamqp.endpoints.ExpiryPolicy expiry_policy: The expiry policy of the Source. + Determines when the expiry timer of a Terminus starts counting down from the timeout value. If the link + is subsequently re-attached before the Terminus is expired, then the count down is aborted. If the + conditions for the terminus-expiry-policy are subsequently re-met, the expiry timer restarts from its + originally configured timeout value. + :param int timeout: Duration, in seconds, that an expiring Source will be retained. + The Source starts expiring as indicated by the expiry-policy. + :param bool dynamic: Request dynamic creation of a remote Node. + When set to true by the receiving Link endpoint, this field constitutes a request for the sending peer + to dynamically create a Node at the source. In this case the address field MUST NOT be set. When set to + true by the sending Link Endpoint this field indicates creation of a dynamically created Node. In this case + the address field will contain the address of the created Node. The generated address SHOULD include the + Link name and Session-name or client-id in some recognizable form for ease of traceability. + :param dict dynamic_node_properties: Properties of the dynamically created Node.
+ If the dynamic field is not set to true this field must be left unset. When set by the receiving Link + endpoint, this field contains the desired properties of the Node the receiver wishes to be created. When + set by the sending Link endpoint this field contains the actual properties of the dynamically created node. + :param ~uamqp.endpoints.DistributionMode distribution_mode: The distribution mode of the Link. + This field MUST be set by the sending end of the Link if the endpoint supports more than one + distribution-mode. This field MAY be set by the receiving end of the Link to indicate a preference when a + Node supports multiple distribution modes. + :param dict filters: A set of predicates to filter the Messages admitted onto the Link. + The receiving endpoint sets its desired filter, the sending endpoint sets the filter actually in place + (including any filters defaulted at the node). The receiving endpoint MUST check that the filter in place + meets its needs and take responsibility for detaching if it does not. + Common filter types, along with the capabilities they are associated with, are registered + here: http://www.amqp.org/specification/1.0/filters. + :param ~uamqp.outcomes.DeliveryState default_outcome: Default outcome for unsettled transfers. + Indicates the outcome to be used for transfers that have not reached a terminal state at the receiver + when the transfer is settled, including when the Source is destroyed. The value MUST be a valid + outcome (e.g. Released or Rejected). + :param list(bytes) outcomes: Descriptors for the outcomes that can be chosen on this link. + The values in this field are the symbolic descriptors of the outcomes that can be chosen on this link. + This field MAY be empty, indicating that the default-outcome will be assumed for all message transfers + (if the default-outcome is not set, and no outcomes are provided, then the accepted outcome must be + supported by the source).
When present, the values MUST be a symbolic descriptor of a valid outcome, + e.g. "amqp:accepted:list". + :param list(bytes) capabilities: The extension capabilities the sender supports/desires. + See http://www.amqp.org/specification/1.0/source-capabilities. + """ + + +Target = namedtuple( + 'Target', + [ + 'address', + 'durable', + 'expiry_policy', + 'timeout', + 'dynamic', + 'dynamic_node_properties', + 'capabilities' + ], + defaults=(None,) * 7 # type: ignore + ) +Target._code = 0x00000029 # type: ignore # pylint: disable=protected-access +Target._definition = ( # type: ignore # pylint: disable=protected-access + FIELD("address", AMQPTypes.string, False, None, False), + FIELD("durable", AMQPTypes.uint, False, "none", False), + FIELD("expiry_policy", AMQPTypes.symbol, False, ExpiryPolicy.SessionEnd, False), + FIELD("timeout", AMQPTypes.uint, False, 0, False), + FIELD("dynamic", AMQPTypes.boolean, False, False, False), + FIELD("dynamic_node_properties", FieldDefinition.node_properties, False, None, False), + FIELD("capabilities", AMQPTypes.symbol, False, None, True)) +if _CAN_ADD_DOCSTRING: + Target.__doc__ = """ + For containers which do not implement address resolution (and do not admit spontaneous link attachment + from their partners) but are instead only used as consumers of messages, it is unnecessary to provide spurious + detail on the target. For this purpose it is possible to use a 'minimal' target in which all the + fields are left unset. + + :param str address: The address of the target. + The address of the target MUST NOT be set when sent on an attach frame sent by the receiving Link Endpoint + where the dynamic flag is set to true (that is where the receiver is requesting the sender to create an + addressable node).
The address of the target MUST be set when sent on an attach frame sent by the sending + Link Endpoint where the dynamic flag is set to true (that is where the sender has created an addressable + node at the request of the receiver and is now communicating the address of that created node). + The generated name of the address SHOULD include the link name and the container-id of the remote container + to allow for ease of identification. + :param ~uamqp.endpoints.TerminusDurability durable: Indicates the durability of the terminus. + Indicates what state of the terminus will be retained durably: the state of durable messages, only + existence and configuration of the terminus, or no state at all. + :param ~uamqp.endpoints.ExpiryPolicy expiry_policy: The expiry policy of the Target. + Determines when the expiry timer of a Terminus starts counting down from the timeout value. If the link + is subsequently re-attached before the Terminus is expired, then the count down is aborted. If the + conditions for the terminus-expiry-policy are subsequently re-met, the expiry timer restarts from its + originally configured timeout value. + :param int timeout: Duration, in seconds, that an expiring Target will be retained. + The Target starts expiring as indicated by the expiry-policy. + :param bool dynamic: Request dynamic creation of a remote Node. + When set to true by the receiving Link endpoint, this field constitutes a request for the sending peer + to dynamically create a Node at the target. In this case the address field MUST NOT be set. When set to + true by the sending Link Endpoint this field indicates creation of a dynamically created Node. In this case + the address field will contain the address of the created Node. The generated address SHOULD include the + Link name and Session-name or client-id in some recognizable form for ease of traceability. + :param dict dynamic_node_properties: Properties of the dynamically created Node.
+ If the dynamic field is not set to true this field must be left unset. When set by the receiving Link + endpoint, this field contains the desired properties of the Node the receiver wishes to be created. When + set by the sending Link endpoint this field contains the actual properties of the dynamically created node. + :param list(bytes) capabilities: The extension capabilities the sender supports/desires. + See http://www.amqp.org/specification/1.0/source-capabilities. + """ diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/error.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/error.py new file mode 100644 index 0000000000000..91f3393eb8bf0 --- /dev/null +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/error.py @@ -0,0 +1,356 @@ +#------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +#-------------------------------------------------------------------------- + +# TODO: fix mypy errors for _code/_definition/__defaults__ (issue #26500) +from enum import Enum +from collections import namedtuple + +from .constants import SECURE_PORT, FIELD +from .types import AMQPTypes, FieldDefinition + + +class ErrorCondition(bytes, Enum): + # Shared error conditions: + + #: An internal error occurred. Operator intervention may be required to resume normal operation. + InternalError = b"amqp:internal-error" + #: A peer attempted to work with a remote entity that does not exist. + NotFound = b"amqp:not-found" + #: A peer attempted to work with a remote entity to which it has no access due to security settings. + UnauthorizedAccess = b"amqp:unauthorized-access" + #: Data could not be decoded. + DecodeError = b"amqp:decode-error" + #: A peer exceeded its resource allocation.
+ ResourceLimitExceeded = b"amqp:resource-limit-exceeded" + #: The peer tried to use a frame in a manner that is inconsistent with the semantics defined in the specification. + NotAllowed = b"amqp:not-allowed" + #: An invalid field was passed in a frame body, and the operation could not proceed. + InvalidField = b"amqp:invalid-field" + #: The peer tried to use functionality that is not implemented in its partner. + NotImplemented = b"amqp:not-implemented" + #: The client attempted to work with a server entity to which it has no access + #: because another client is working with it. + ResourceLocked = b"amqp:resource-locked" + #: The client made a request that was not allowed because some precondition failed. + PreconditionFailed = b"amqp:precondition-failed" + #: A server entity the client is working with has been deleted. + ResourceDeleted = b"amqp:resource-deleted" + #: The peer sent a frame that is not permitted in the current state of the Session. + IllegalState = b"amqp:illegal-state" + #: The peer cannot send a frame because the smallest encoding of the performative with the currently + #: valid values would be too large to fit within a frame of the agreed maximum frame size. + FrameSizeTooSmall = b"amqp:frame-size-too-small" + + # Symbols used to indicate connection error conditions: + + #: An operator intervened to close the Connection for some reason. The client may retry at some later date. + ConnectionCloseForced = b"amqp:connection:forced" + #: A valid frame header cannot be formed from the incoming byte stream. + ConnectionFramingError = b"amqp:connection:framing-error" + #: The container is no longer available on the current connection. The peer should attempt reconnection + #: to the container using the details provided in the info map. + ConnectionRedirect = b"amqp:connection:redirect" + + # Symbols used to indicate session error conditions: + + #: The peer violated incoming window for the session. 
+ SessionWindowViolation = b"amqp:session:window-violation" + #: Input was received for a link that was detached with an error. + SessionErrantLink = b"amqp:session:errant-link" + #: An attach was received using a handle that is already in use for an attached Link. + SessionHandleInUse = b"amqp:session:handle-in-use" + #: A frame (other than attach) was received referencing a handle which + #: is not currently in use by an attached Link. + SessionUnattachedHandle = b"amqp:session:unattached-handle" + + # Symbols used to indicate link error conditions: + + #: An operator intervened to detach for some reason. + LinkDetachForced = b"amqp:link:detach-forced" + #: The peer sent more Message transfers than currently allowed on the link. + LinkTransferLimitExceeded = b"amqp:link:transfer-limit-exceeded" + #: The peer sent a larger message than is supported on the link. + LinkMessageSizeExceeded = b"amqp:link:message-size-exceeded" + #: The address provided cannot be resolved to a terminus at the current container. + LinkRedirect = b"amqp:link:redirect" + #: The link has been attached elsewhere, causing the existing attachment to be forcibly closed. + LinkStolen = b"amqp:link:stolen" + + # Customized symbols used to indicate client error conditions.
+ # TODO: check whether Client/Unknown/Vendor Error are exposed in EH/SB as users might be depending + # on the code for error handling + ClientError = b"amqp:client-error" + UnknownError = b"amqp:unknown-error" + VendorError = b"amqp:vendor-error" + SocketError = b"amqp:socket-error" + + +class RetryMode(str, Enum): # pylint: disable=enum-must-inherit-case-insensitive-enum-meta + EXPONENTIAL = 'exponential' + FIXED = 'fixed' + + +class RetryPolicy: + + no_retry = [ + ErrorCondition.DecodeError, + ErrorCondition.LinkMessageSizeExceeded, + ErrorCondition.NotFound, + ErrorCondition.NotImplemented, + ErrorCondition.LinkRedirect, + ErrorCondition.NotAllowed, + ErrorCondition.UnauthorizedAccess, + ErrorCondition.LinkStolen, + ErrorCondition.ResourceLimitExceeded, + ErrorCondition.ConnectionRedirect, + ErrorCondition.PreconditionFailed, + ErrorCondition.InvalidField, + ErrorCondition.ResourceDeleted, + ErrorCondition.IllegalState, + ErrorCondition.FrameSizeTooSmall, + ErrorCondition.ConnectionFramingError, + ErrorCondition.SessionUnattachedHandle, + ErrorCondition.SessionHandleInUse, + ErrorCondition.SessionErrantLink, + ErrorCondition.SessionWindowViolation + ] + + def __init__( + self, + **kwargs + ): + """ + :keyword int retry_total: The total number of retry attempts. Default is 3. + :keyword float retry_backoff_factor: The base back-off interval in seconds. Default is 0.8. + :keyword float retry_backoff_max: The maximum back-off time in seconds. Default is 120. + :keyword RetryMode retry_mode: Fixed or exponential back-off. Default is RetryMode.EXPONENTIAL. + :keyword list no_retry: Additional error conditions that should not be retried. + :keyword dict custom_condition_backoff: Optional mapping of error condition to back-off time. + """ + self.total_retries = kwargs.pop('retry_total', 3) + # TODO: A. consider letting retry_backoff_factor be either a float or a callback obj which returns a float + # to give more extensibility on customization of retry backoff time, the callback could take the exception + # as input.
+ self.backoff_factor = kwargs.pop('retry_backoff_factor', 0.8) + self.backoff_max = kwargs.pop('retry_backoff_max', 120) + self.retry_mode = kwargs.pop('retry_mode', RetryMode.EXPONENTIAL) + # Copy the class-level no_retry list before extending so shared state is not mutated across instances. + self.no_retry = list(self.no_retry) + kwargs.get('no_retry', []) + self.custom_condition_backoff = kwargs.pop("custom_condition_backoff", None) + # TODO: B. As an alternative to option A, we could have a new kwarg serve this goal + + def configure_retries(self, **kwargs): + return { + 'total': kwargs.pop("retry_total", self.total_retries), + 'backoff': kwargs.pop("retry_backoff_factor", self.backoff_factor), + 'max_backoff': kwargs.pop("retry_backoff_max", self.backoff_max), + 'retry_mode': kwargs.pop("retry_mode", self.retry_mode), + 'history': [] + } + + def increment(self, settings, error): # pylint: disable=no-self-use + settings['total'] -= 1 + settings['history'].append(error) + if settings['total'] < 0: + return False + return True + + def is_retryable(self, error): + try: + if error.condition in self.no_retry: + return False + except TypeError: + pass + return True + + def get_backoff_time(self, settings, error): + try: + return self.custom_condition_backoff[error.condition] + except (KeyError, TypeError): + pass + + consecutive_errors_len = len(settings['history']) + if consecutive_errors_len <= 1: + return 0 + + if self.retry_mode == RetryMode.FIXED: + backoff_value = settings['backoff'] + else: + backoff_value = settings['backoff'] * (2 ** (consecutive_errors_len - 1)) + return min(settings['max_backoff'], backoff_value) + + +AMQPError = namedtuple('AMQPError', ['condition', 'description', 'info'], defaults=[None, None]) +AMQPError.__new__.__defaults__ = (None,) * len(AMQPError._fields) # type: ignore +AMQPError._code = 0x0000001d # type: ignore # pylint: disable=protected-access +AMQPError._definition = ( # type: ignore # pylint: disable=protected-access + FIELD('condition', AMQPTypes.symbol, True, None, False), + FIELD('description', AMQPTypes.string, False, None, False), +
FIELD('info', FieldDefinition.fields, False, None, False), +) + + +class AMQPException(Exception): + """Base exception for all errors. + + :param bytes condition: The error code. + :keyword str description: A description of the error. + :keyword dict info: A dictionary of additional data associated with the error. + """ + def __init__(self, condition, **kwargs): + self.condition = condition or ErrorCondition.UnknownError + self.description = kwargs.get("description", None) + self.info = kwargs.get("info", None) + self.message = kwargs.get("message", None) + self.inner_error = kwargs.get("error", None) + message = self.message or "Error condition: {}".format( + str(condition) if isinstance(condition, ErrorCondition) else condition.decode() + ) + if self.description: + try: + message += "\n Error Description: {}".format(self.description.decode()) + except (TypeError, AttributeError): + message += "\n Error Description: {}".format(self.description) + super(AMQPException, self).__init__(message) + + +class AMQPDecodeError(AMQPException): + """An error occurred while decoding an incoming frame. + + """ + + +class AMQPConnectionError(AMQPException): + """Details of a Connection-level error. + + """ + + +class AMQPConnectionRedirect(AMQPConnectionError): + """Details of a Connection-level redirect response. + + The container is no longer available on the current connection. + The peer should attempt reconnection to the container using the details provided. + + :param bytes condition: The error code. + :keyword str description: A description of the error. + :keyword dict info: A dictionary of additional data associated with the error. 
+ """ + def __init__(self, condition, description=None, info=None): + self.hostname = info.get(b'hostname', b'').decode('utf-8') + self.network_host = info.get(b'network-host', b'').decode('utf-8') + self.port = int(info.get(b'port', SECURE_PORT)) + super(AMQPConnectionRedirect, self).__init__(condition, description=description, info=info) + + +class AMQPSessionError(AMQPException): + """Details of a Session-level error. + + :param bytes condition: The error code. + :keyword str description: A description of the error. + :keyword dict info: A dictionary of additional data associated with the error. + """ + + +class AMQPLinkError(AMQPException): + """Details of a Link-level error. + + :param bytes condition: The error code. + :keyword str description: A description of the error. + :keyword dict info: A dictionary of additional data associated with the error. + """ + + +class AMQPLinkRedirect(AMQPLinkError): + """Details of a Link-level redirect response. + + The address provided cannot be resolved to a terminus at the current container. + The supplied information may allow the client to locate and attach to the terminus. + + :param bytes condition: The error code. + :keyword str description: A description of the error. + :keyword dict info: A dictionary of additional data associated with the error. + """ + + def __init__(self, condition, description=None, info=None): + self.hostname = info.get(b'hostname', b'').decode('utf-8') + self.network_host = info.get(b'network-host', b'').decode('utf-8') + self.port = int(info.get(b'port', SECURE_PORT)) + self.address = info.get(b'address', b'').decode('utf-8') + super().__init__(condition, description=description, info=info) + + +class AuthenticationException(AMQPException): + """Details of an Authentication error. + + :param bytes condition: The error code. + :keyword str description: A description of the error. + :keyword dict info: A dictionary of additional data associated with the error.
+ """ + + +class TokenExpired(AuthenticationException): + """Details of a Token expiration error. + + :param bytes condition: The error code. + :keyword str description: A description of the error. + :keyword dict info: A dictionary of additional data associated with the error. + """ + + +class TokenAuthFailure(AuthenticationException): + """Failure to authenticate with token.""" + + def __init__(self, status_code, status_description, **kwargs): + encoding = kwargs.get("encoding", 'utf-8') + self.status_code = status_code + self.status_description = status_description + message = "CBS Token authentication failed.\nStatus code: {}".format(self.status_code) + if self.status_description: + try: + message += "\nDescription: {}".format(self.status_description.decode(encoding)) + except (TypeError, AttributeError): + message += "\nDescription: {}".format(self.status_description) + super(TokenAuthFailure, self).__init__(condition=ErrorCondition.ClientError, message=message) + + +class MessageException(AMQPException): + """Details of a Message error. + + :param bytes condition: The error code. + :keyword str description: A description of the error. + :keyword dict info: A dictionary of additional data associated with the error. + + """ + + +class MessageSendFailed(MessageException): + """Details of a Message send failed error. + + :param bytes condition: The error code. + :keyword str description: A description of the error. + :keyword dict info: A dictionary of additional data associated with the error. 
+ """ + + +class ErrorResponse(object): + """AMQP error object.""" + + def __init__(self, **kwargs): + self.condition = kwargs.get("condition") + self.description = kwargs.get("description") + + info = kwargs.get("info") + error_info = kwargs.get("error_info") + if isinstance(error_info, list) and len(error_info) >= 1: + if isinstance(error_info[0], list) and len(error_info[0]) >= 1: + self.condition = error_info[0][0] + if len(error_info[0]) >= 2: + self.description = error_info[0][1] + if len(error_info[0]) >= 3: + info = error_info[0][2] + + self.info = info + self.error = error_info diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/link.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/link.py new file mode 100644 index 0000000000000..9234d024f9f36 --- /dev/null +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/link.py @@ -0,0 +1,259 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + + +from typing import Optional +import uuid +import logging + +from .endpoints import Source, Target +from .constants import DEFAULT_LINK_CREDIT, SessionState, LinkState, Role, SenderSettleMode, ReceiverSettleMode +from .performatives import AttachFrame, DetachFrame + +from .error import ErrorCondition, AMQPLinkError, AMQPLinkRedirect, AMQPConnectionError + +_LOGGER = logging.getLogger(__name__) + + +class Link(object): # pylint: disable=too-many-instance-attributes + """An AMQP Link. + + This object should not be used directly - instead use one of the directional + derivatives: Sender or Receiver.
+ """ + + def __init__(self, session, handle, name, role, **kwargs): + self.state = LinkState.DETACHED + self.name = name or str(uuid.uuid4()) + self.handle = handle + self.remote_handle = None + self.role = role + source_address = kwargs["source_address"] + target_address = kwargs["target_address"] + self.source = ( + source_address + if isinstance(source_address, Source) + else Source( + address=kwargs["source_address"], + durable=kwargs.get("source_durable"), + expiry_policy=kwargs.get("source_expiry_policy"), + timeout=kwargs.get("source_timeout"), + dynamic=kwargs.get("source_dynamic"), + dynamic_node_properties=kwargs.get("source_dynamic_node_properties"), + distribution_mode=kwargs.get("source_distribution_mode"), + filters=kwargs.get("source_filters"), + default_outcome=kwargs.get("source_default_outcome"), + outcomes=kwargs.get("source_outcomes"), + capabilities=kwargs.get("source_capabilities"), + ) + ) + self.target = ( + target_address + if isinstance(target_address, Target) + else Target( + address=kwargs["target_address"], + durable=kwargs.get("target_durable"), + expiry_policy=kwargs.get("target_expiry_policy"), + timeout=kwargs.get("target_timeout"), + dynamic=kwargs.get("target_dynamic"), + dynamic_node_properties=kwargs.get("target_dynamic_node_properties"), + capabilities=kwargs.get("target_capabilities"), + ) + ) + self.link_credit = kwargs.pop("link_credit", None) or DEFAULT_LINK_CREDIT + self.current_link_credit = self.link_credit + self.send_settle_mode = kwargs.pop("send_settle_mode", SenderSettleMode.Mixed) + self.rcv_settle_mode = kwargs.pop("rcv_settle_mode", ReceiverSettleMode.First) + self.unsettled = kwargs.pop("unsettled", None) + self.incomplete_unsettled = kwargs.pop("incomplete_unsettled", None) + self.initial_delivery_count = kwargs.pop("initial_delivery_count", 0) + self.delivery_count = self.initial_delivery_count + self.received_delivery_id = None + self.max_message_size = kwargs.pop("max_message_size", None) + 
self.remote_max_message_size = None + self.available = kwargs.pop("available", None) + self.properties = kwargs.pop("properties", None) + self.remote_properties = None + self.offered_capabilities = None + self.desired_capabilities = kwargs.pop("desired_capabilities", None) + + self.network_trace = kwargs["network_trace"] + self.network_trace_params = kwargs["network_trace_params"] + self.network_trace_params["amqpLink"] = self.name + self._session = session + self._is_closed = False + self._on_link_state_change = kwargs.get("on_link_state_change") + self._on_attach = kwargs.get("on_attach") + self._error = None + + def __enter__(self): + self.attach() + return self + + def __exit__(self, *args): + self.detach(close=True) + + @classmethod + def from_incoming_frame(cls, session, handle, frame): + # TODO: Assuming we establish all links for now... + # check link_create_from_endpoint in C lib + raise NotImplementedError("Pending") + + def get_state(self): + try: + raise self._error + except TypeError: + pass + return self.state + + def _check_if_closed(self): + if self._is_closed: + try: + raise self._error + except TypeError: + raise AMQPConnectionError(condition=ErrorCondition.InternalError, description="Link already closed.") + + def _set_state(self, new_state): + # type: (LinkState) -> None + """Update the session state.""" + if new_state is None: + return + previous_state = self.state + self.state = new_state + _LOGGER.info("Link state changed: %r -> %r", previous_state, new_state, extra=self.network_trace_params) + try: + self._on_link_state_change(previous_state, new_state) + except TypeError: + pass + except Exception as e: # pylint: disable=broad-except + _LOGGER.error("Link state change callback failed: '%r'", e, extra=self.network_trace_params) + + def _on_session_state_change(self): + if self._session.state == SessionState.MAPPED: + if not self._is_closed and self.state == LinkState.DETACHED: + self._outgoing_attach() + 
self._set_state(LinkState.ATTACH_SENT) + elif self._session.state == SessionState.DISCARDING: + self._set_state(LinkState.DETACHED) + + def _outgoing_attach(self): + self.delivery_count = self.initial_delivery_count + attach_frame = AttachFrame( + name=self.name, + handle=self.handle, + role=self.role, + send_settle_mode=self.send_settle_mode, + rcv_settle_mode=self.rcv_settle_mode, + source=self.source, + target=self.target, + unsettled=self.unsettled, + incomplete_unsettled=self.incomplete_unsettled, + initial_delivery_count=self.initial_delivery_count if self.role == Role.Sender else None, + max_message_size=self.max_message_size, + offered_capabilities=self.offered_capabilities if self.state == LinkState.ATTACH_RCVD else None, + desired_capabilities=self.desired_capabilities if self.state == LinkState.DETACHED else None, + properties=self.properties, + ) + if self.network_trace: + _LOGGER.debug("-> %r", attach_frame, extra=self.network_trace_params) + self._session._outgoing_attach(attach_frame) # pylint: disable=protected-access + + def _incoming_attach(self, frame): + if self.network_trace: + _LOGGER.debug("<- %r", AttachFrame(*frame), extra=self.network_trace_params) + if self._is_closed: + raise ValueError("Invalid link") + if not frame[5] or not frame[6]: + _LOGGER.info("Cannot get source or target. 
Detaching link", extra=self.network_trace_params) + self._set_state(LinkState.DETACHED) + raise ValueError("Invalid link") + self.remote_handle = frame[1] # handle + self.remote_max_message_size = frame[10] # max_message_size + self.offered_capabilities = frame[11] # offered_capabilities + self.remote_properties = frame[13] + if self.state == LinkState.DETACHED: + self._set_state(LinkState.ATTACH_RCVD) + elif self.state == LinkState.ATTACH_SENT: + self._set_state(LinkState.ATTACHED) + if self._on_attach: + try: + if frame[5]: + frame[5] = Source(*frame[5]) + if frame[6]: + frame[6] = Target(*frame[6]) + self._on_attach(AttachFrame(*frame)) + except Exception as e: # pylint: disable=broad-except + _LOGGER.warning("Callback for link attach raised error: %r", e, extra=self.network_trace_params) + + def _outgoing_flow(self, **kwargs): + flow_frame = { + "handle": self.handle, + "delivery_count": self.delivery_count, + "link_credit": self.current_link_credit, + "available": kwargs.get("available"), + "drain": kwargs.get("drain"), + "echo": kwargs.get("echo"), + "properties": kwargs.get("properties"), + } + self._session._outgoing_flow(flow_frame) # pylint: disable=protected-access + + def _incoming_flow(self, frame): + pass + + def _incoming_disposition(self, frame): + pass + + def _outgoing_detach(self, close=False, error=None): + detach_frame = DetachFrame(handle=self.handle, closed=close, error=error) + if self.network_trace: + _LOGGER.debug("-> %r", detach_frame, extra=self.network_trace_params) + self._session._outgoing_detach(detach_frame) # pylint: disable=protected-access + if close: + self._is_closed = True + + def _incoming_detach(self, frame): + if self.network_trace: + _LOGGER.debug("<- %r", DetachFrame(*frame), extra=self.network_trace_params) + if self.state == LinkState.ATTACHED: + self._outgoing_detach(close=frame[1]) # closed + elif frame[1] and not self._is_closed and self.state in [LinkState.ATTACH_SENT, LinkState.ATTACH_RCVD]: + # Received a closing 
detach after we sent a non-closing detach. + # In this case, we MUST signal that we closed by reattaching and then sending a closing detach. + self._outgoing_attach() + self._outgoing_detach(close=True) + # TODO: on_detach_hook + if frame[2]: # error + # frame[2][0] is condition, frame[2][1] is description, frame[2][2] is info + error_cls = AMQPLinkRedirect if frame[2][0] == ErrorCondition.LinkRedirect else AMQPLinkError + self._error = error_cls(condition=frame[2][0], description=frame[2][1], info=frame[2][2]) + self._set_state(LinkState.ERROR) + else: + self._set_state(LinkState.DETACHED) + + def attach(self): + if self._is_closed: + raise ValueError("Link already closed.") + self._outgoing_attach() + self._set_state(LinkState.ATTACH_SENT) + + def detach(self, close=False, error=None): + if self.state in (LinkState.DETACHED, LinkState.ERROR): + return + try: + self._check_if_closed() + if self.state in [LinkState.ATTACH_SENT, LinkState.ATTACH_RCVD]: + self._outgoing_detach(close=close, error=error) + self._set_state(LinkState.DETACHED) + elif self.state == LinkState.ATTACHED: + self._outgoing_detach(close=close, error=error) + self._set_state(LinkState.DETACH_SENT) + except Exception as exc: # pylint: disable=broad-except + _LOGGER.info("An error occurred when detaching the link: %r", exc, extra=self.network_trace_params) + self._set_state(LinkState.DETACHED) + + def flow(self, *, link_credit: Optional[int] = None, **kwargs) -> None: + self.current_link_credit = link_credit if link_credit is not None else self.link_credit + self._outgoing_flow(**kwargs) diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/management_link.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/management_link.py new file mode 100644 index 0000000000000..c5b1e6c0aa199 --- /dev/null +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/management_link.py @@ -0,0 +1,262 @@ +#------------------------------------------------------------------------- +# 
Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +#-------------------------------------------------------------------------- + +import time +import logging +from functools import partial +from collections import namedtuple + +from .sender import SenderLink +from .receiver import ReceiverLink +from .constants import ( + ManagementLinkState, + LinkState, + SenderSettleMode, + ReceiverSettleMode, + ManagementExecuteOperationResult, + ManagementOpenResult, + SEND_DISPOSITION_REJECT, + MessageDeliveryState, + LinkDeliverySettleReason +) +from .error import AMQPException, ErrorCondition +from .message import Properties, _MessageDelivery + +_LOGGER = logging.getLogger(__name__) + +PendingManagementOperation = namedtuple('PendingManagementOperation', ['message', 'on_execute_operation_complete']) + + +class ManagementLink(object): # pylint:disable=too-many-instance-attributes + """A paired sender/receiver link over a single session, used to issue + AMQP management operations against an endpoint. + """ + def __init__(self, session, endpoint, **kwargs): + self.next_message_id = 0 + self.state = ManagementLinkState.IDLE + self._pending_operations = [] + self._session = session + self._network_trace_params = kwargs.get('network_trace_params') + self._request_link: SenderLink = session.create_sender_link( + endpoint, + source_address=endpoint, + on_link_state_change=self._on_sender_state_change, + send_settle_mode=SenderSettleMode.Unsettled, + rcv_settle_mode=ReceiverSettleMode.First, + network_trace=kwargs.get("network_trace", False) + ) + self._response_link: ReceiverLink = session.create_receiver_link( + endpoint, + target_address=endpoint, + on_link_state_change=self._on_receiver_state_change, + on_transfer=self._on_message_received, + send_settle_mode=SenderSettleMode.Unsettled, + rcv_settle_mode=ReceiverSettleMode.First, + network_trace=kwargs.get("network_trace", False) + ) + self._on_amqp_management_error = kwargs.get('on_amqp_management_error') +
self._on_amqp_management_open_complete = kwargs.get('on_amqp_management_open_complete') + + self._status_code_field = kwargs.get('status_code_field', b'statusCode') + self._status_description_field = kwargs.get('status_description_field', b'statusDescription') + + self._sender_connected = False + self._receiver_connected = False + + def __enter__(self): + self.open() + return self + + def __exit__(self, *args): + self.close() + + def _on_sender_state_change(self, previous_state, new_state): + _LOGGER.info( + "Management link sender state changed: %r -> %r", + previous_state, + new_state, + extra=self._network_trace_params + ) + if new_state == previous_state: + return + if self.state == ManagementLinkState.OPENING: + if new_state == LinkState.ATTACHED: + self._sender_connected = True + if self._receiver_connected: + self.state = ManagementLinkState.OPEN + self._on_amqp_management_open_complete(ManagementOpenResult.OK) + elif new_state in [LinkState.DETACHED, LinkState.DETACH_SENT, LinkState.DETACH_RCVD, LinkState.ERROR]: + self.state = ManagementLinkState.IDLE + self._on_amqp_management_open_complete(ManagementOpenResult.ERROR) + elif self.state == ManagementLinkState.OPEN: + if new_state is not LinkState.ATTACHED: + self.state = ManagementLinkState.ERROR + self._on_amqp_management_error() + elif self.state == ManagementLinkState.CLOSING: + if new_state not in [LinkState.DETACHED, LinkState.DETACH_SENT, LinkState.DETACH_RCVD]: + self.state = ManagementLinkState.ERROR + self._on_amqp_management_error() + elif self.state == ManagementLinkState.ERROR: + # All state transitions shall be ignored. 
+ return + + def _on_receiver_state_change(self, previous_state, new_state): + _LOGGER.info( + "Management link receiver state changed: %r -> %r", + previous_state, + new_state, + extra=self._network_trace_params + ) + if new_state == previous_state: + return + if self.state == ManagementLinkState.OPENING: + if new_state == LinkState.ATTACHED: + self._receiver_connected = True + if self._sender_connected: + self.state = ManagementLinkState.OPEN + self._on_amqp_management_open_complete(ManagementOpenResult.OK) + elif new_state in [LinkState.DETACHED, LinkState.DETACH_SENT, LinkState.DETACH_RCVD, LinkState.ERROR]: + self.state = ManagementLinkState.IDLE + self._on_amqp_management_open_complete(ManagementOpenResult.ERROR) + elif self.state == ManagementLinkState.OPEN: + if new_state is not LinkState.ATTACHED: + self.state = ManagementLinkState.ERROR + self._on_amqp_management_error() + elif self.state == ManagementLinkState.CLOSING: + if new_state not in [LinkState.DETACHED, LinkState.DETACH_SENT, LinkState.DETACH_RCVD]: + self.state = ManagementLinkState.ERROR + self._on_amqp_management_error() + elif self.state == ManagementLinkState.ERROR: + # All state transitions shall be ignored. 
+ return + + def _on_message_received(self, _, message): + message_properties = message.properties + correlation_id = message_properties[5] + response_detail = message.application_properties + + status_code = response_detail.get(self._status_code_field) + status_description = response_detail.get(self._status_description_field) + + to_remove_operation = None + for operation in self._pending_operations: + if operation.message.properties.message_id == correlation_id: + to_remove_operation = operation + break + if to_remove_operation: + mgmt_result = ManagementExecuteOperationResult.OK \ + if 200 <= status_code <= 299 else ManagementExecuteOperationResult.FAILED_BAD_STATUS + to_remove_operation.on_execute_operation_complete( + mgmt_result, + status_code, + status_description, + message, + response_detail.get(b'error-condition') + ) + self._pending_operations.remove(to_remove_operation) + + def _on_send_complete(self, message_delivery, reason, state): # TODO: check spec for handling of other settle reasons + if reason == LinkDeliverySettleReason.DISPOSITION_RECEIVED and SEND_DISPOSITION_REJECT in state: + # sample reject state: {'rejected': [[b'amqp:not-allowed', b"Invalid command 'RE1AD'.", None]]} + to_remove_operation = None + for operation in self._pending_operations: + if message_delivery.message == operation.message: + to_remove_operation = operation + break + if to_remove_operation is None: + # No matching pending operation; nothing to settle. + return + self._pending_operations.remove(to_remove_operation) + # TODO: better error handling + # AMQPException is too general? to be more specific: MessageReject(Error) or AMQPManagementError? + # or should there be an error mapping which maps the condition to the error type + to_remove_operation.on_execute_operation_complete( # The callback is defined in management_operation.py + ManagementExecuteOperationResult.ERROR, + None, + None, + message_delivery.message, + error=AMQPException( + condition=state[SEND_DISPOSITION_REJECT][0][0], # 0 is error condition + description=state[SEND_DISPOSITION_REJECT][0][1], # 1 is error description + info=state[SEND_DISPOSITION_REJECT][0][2], # 2 is error info + ) + ) + + def open(self): + if self.state != ManagementLinkState.IDLE: + raise ValueError("Management link is already open or opening.") + self.state = ManagementLinkState.OPENING + self._response_link.attach() + self._request_link.attach() + + def execute_operation( + self, + message, + on_execute_operation_complete, + **kwargs + ): + """Execute a request and wait on a response. + + :param message: The message to send in the management request. + :type message: ~uamqp.message.Message + :param on_execute_operation_complete: Callback to be called when the operation is complete. + The following values will be passed to the callback: operation_id, operation_result, status_code, + status_description, raw_message and error. + :type on_execute_operation_complete: Callable[[str, str, int, str, ~uamqp.message.Message, Exception], None] + :keyword operation: The type of operation to be performed. This value will + be service-specific, but common values include READ, CREATE and UPDATE. + This value will be added as an application property on the message. + :paramtype operation: bytes or str + :keyword type: The type on which to carry out the operation. This will + be specific to the entities of the service. This value will be added as + an application property on the message. + :paramtype type: bytes or str + :keyword str locales: A list of locales that the sending peer permits for incoming + informational text in response messages.
+ :keyword float timeout: Provide an optional timeout in seconds within which a response + to the management request must be received. + :rtype: None + """ + timeout = kwargs.get("timeout") + message.application_properties["operation"] = kwargs.get("operation") + message.application_properties["type"] = kwargs.get("type") + if "locales" in kwargs: + message.application_properties["locales"] = kwargs.get("locales") + try: + # TODO: namedtuple is immutable, which may push us to re-think about the namedtuple approach for Message + new_properties = message.properties._replace(message_id=self.next_message_id) + except AttributeError: + new_properties = Properties(message_id=self.next_message_id) + message = message._replace(properties=new_properties) + expire_time = (time.time() + timeout) if timeout else None + message_delivery = _MessageDelivery( + message, + MessageDeliveryState.WaitingToBeSent, + expire_time + ) + + on_send_complete = partial(self._on_send_complete, message_delivery) + + self._request_link.send_transfer( + message, + on_send_complete=on_send_complete, + timeout=timeout + ) + self.next_message_id += 1 + self._pending_operations.append(PendingManagementOperation(message, on_execute_operation_complete)) + + def close(self): + if self.state != ManagementLinkState.IDLE: + self.state = ManagementLinkState.CLOSING + self._response_link.detach(close=True) + self._request_link.detach(close=True) + for pending_operation in self._pending_operations: + pending_operation.on_execute_operation_complete( + ManagementExecuteOperationResult.LINK_CLOSED, + None, + None, + pending_operation.message, + AMQPException(condition=ErrorCondition.ClientError, description="Management link already closed.") + ) + self._pending_operations = [] + self.state = ManagementLinkState.IDLE diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/management_operation.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/management_operation.py new file mode 100644 
index 0000000000000..475c3424a897f --- /dev/null +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/management_operation.py @@ -0,0 +1,140 @@ +#------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +#-------------------------------------------------------------------------- +import logging +import uuid +import time +from functools import partial + +from .management_link import ManagementLink +from .error import ( + AMQPLinkError, + ErrorCondition +) + +from .constants import ( + ManagementOpenResult, + ManagementExecuteOperationResult +) + +_LOGGER = logging.getLogger(__name__) + + +class ManagementOperation(object): + def __init__(self, session, endpoint='$management', **kwargs): + self._mgmt_link_open_status = None + + self._session = session + self._connection = self._session._connection + self._network_trace_params = { + "amqpConnection": self._session._connection._container_id, + "amqpSession": self._session.name, + "amqpLink": None + } + self._mgmt_link = self._session.create_request_response_link_pair( + endpoint=endpoint, + on_amqp_management_open_complete=self._on_amqp_management_open_complete, + on_amqp_management_error=self._on_amqp_management_error, + **kwargs + ) # type: ManagementLink + self._responses = {} + self._mgmt_error = None + + def _on_amqp_management_open_complete(self, result): + """Callback run when the send/receive links are open and ready + to process messages. + + :param result: Whether the link opening was successful. 
+ :type result: int + """ + self._mgmt_link_open_status = result + + def _on_amqp_management_error(self): + """Callback run if an error occurs in the send/receive links.""" + # TODO: This probably shouldn't be ValueError + self._mgmt_error = ValueError("Management Operation error occurred.") + + def _on_execute_operation_complete( + self, + operation_id, + operation_result, + status_code, + status_description, + raw_message, + error=None + ): + _LOGGER.debug( + "Management operation completed, id: %r; result: %r; code: %r; description: %r, error: %r", + operation_id, + operation_result, + status_code, + status_description, + error, + extra=self._network_trace_params + ) + + if operation_result in\ + (ManagementExecuteOperationResult.ERROR, ManagementExecuteOperationResult.LINK_CLOSED): + self._mgmt_error = error + _LOGGER.error( + "Failed to complete management operation due to error: %r.", + error, + extra=self._network_trace_params + ) + else: + self._responses[operation_id] = (status_code, status_description, raw_message) + + def execute(self, message, operation=None, operation_type=None, timeout=0): + start_time = time.time() + operation_id = str(uuid.uuid4()) + self._responses[operation_id] = None + self._mgmt_error = None + + self._mgmt_link.execute_operation( + message, + partial(self._on_execute_operation_complete, operation_id), + timeout=timeout, + operation=operation, + type=operation_type + ) + + while not self._responses[operation_id] and not self._mgmt_error: + if timeout and timeout > 0: + now = time.time() + if (now - start_time) >= timeout: + raise TimeoutError("Failed to receive mgmt response in {} seconds".format(timeout)) + self._connection.listen() + + if self._mgmt_error: + self._responses.pop(operation_id) + raise self._mgmt_error # pylint: disable=raising-bad-type + + response = self._responses.pop(operation_id) + return response + + def open(self): + self._mgmt_link_open_status = ManagementOpenResult.OPENING + self._mgmt_link.open() + + def
ready(self): + try: + raise self._mgmt_error # pylint: disable=raising-bad-type + except TypeError: + pass + + if self._mgmt_link_open_status == ManagementOpenResult.OPENING: + return False + if self._mgmt_link_open_status == ManagementOpenResult.OK: + return True + # ManagementOpenResult.ERROR or CANCELLED + # TODO: update below with correct status code + info + raise AMQPLinkError( + condition=ErrorCondition.ClientError, + description="Failed to open mgmt link, management link status: {}".format(self._mgmt_link_open_status), + info=None + ) + + def close(self): + self._mgmt_link.close() diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/message.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/message.py new file mode 100644 index 0000000000000..233092ae184a8 --- /dev/null +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/message.py @@ -0,0 +1,272 @@ +#------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. 
+#-------------------------------------------------------------------------- + +# TODO: fix mypy errors for _code/_definition/__defaults__ (issue #26500) +from collections import namedtuple + +from .types import AMQPTypes, FieldDefinition +from .constants import FIELD, MessageDeliveryState +from .performatives import _CAN_ADD_DOCSTRING + + +Header = namedtuple( + 'Header', + [ + 'durable', + 'priority', + 'ttl', + 'first_acquirer', + 'delivery_count' + ], + defaults=(None,) * 5 # type: ignore + ) +Header._code = 0x00000070 # type: ignore # pylint:disable=protected-access +Header._definition = ( # type: ignore # pylint:disable=protected-access + FIELD("durable", AMQPTypes.boolean, False, None, False), + FIELD("priority", AMQPTypes.ubyte, False, None, False), + FIELD("ttl", AMQPTypes.uint, False, None, False), + FIELD("first_acquirer", AMQPTypes.boolean, False, None, False), + FIELD("delivery_count", AMQPTypes.uint, False, None, False)) +if _CAN_ADD_DOCSTRING: + Header.__doc__ = """ + Transport headers for a Message. + + The header section carries standard delivery details about the transfer of a Message through the AMQP + network. If the header section is omitted the receiver MUST assume the appropriate default values for + the fields within the header unless other target or node specific defaults have otherwise been set. + + :param bool durable: Specify durability requirements. + Durable Messages MUST NOT be lost even if an intermediary is unexpectedly terminated and restarted. + A target which is not capable of fulfilling this guarantee MUST NOT accept messages where the durable + header is set to true: if the source allows the rejected outcome then the message should be rejected + with the precondition-failed error, otherwise the link must be detached by the receiver with the same error. + :param int priority: Relative Message priority. + This field contains the relative Message priority. Higher numbers indicate higher priority Messages. 
+ Messages with higher priorities MAY be delivered before those with lower priorities. An AMQP intermediary + implementing distinct priority levels MUST do so in the following manner: + + - If n distinct priorities are implemented and n is less than 10 - priorities 0 to (5 - ceiling(n/2)) + MUST be treated equivalently and MUST be the lowest effective priority. The priorities (4 + floor(n/2)) + and above MUST be treated equivalently and MUST be the highest effective priority. The priorities + (5 - ceiling(n/2)) to (4 + floor(n/2)) inclusive MUST be treated as distinct priorities. + - If n distinct priorities are implemented and n is 10 or greater - priorities 0 to (n - 1) MUST be + distinct, and priorities n and above MUST be equivalent to priority (n - 1). Thus, for example, if 2 + distinct priorities are implemented, then levels 0 to 4 are equivalent, and levels 5 to 9 are equivalent + and levels 4 and 5 are distinct. If 3 distinct priorities are implemented, then 0 to 3 are equivalent, + 5 to 9 are equivalent and 3, 4 and 5 are distinct. This scheme ensures that if two priorities are distinct + for a server which implements m separate priority levels they are also distinct for a server which + implements n different priority levels where n > m. + + :param int ttl: Time to live in ms. + Duration in milliseconds for which the Message should be considered 'live'. If this is set then a message + expiration time will be computed based on the time of arrival at an intermediary. Messages that live longer + than their expiration time will be discarded (or dead lettered). When a message is transmitted by an + intermediary that was received with a ttl, the transmitted message's header should contain a ttl that is + computed as the difference between the current time and the formerly computed message expiration + time, i.e. the reduced ttl, so that messages will eventually die if they end up in a delivery loop.
+ :param bool first_acquirer: If this value is true, then this message has not been acquired by any other Link. + If this value is false, then this message may have previously been acquired by another Link or Links. + :param int delivery_count: The number of prior unsuccessful delivery attempts. + The number of unsuccessful previous attempts to deliver this message. If this value is non-zero it may + be taken as an indication that the delivery may be a duplicate. On first delivery, the value is zero. + It is incremented upon an outcome being settled at the sender, according to rules defined for each outcome. + """ + + +Properties = namedtuple( + 'Properties', + [ + 'message_id', + 'user_id', + 'to', + 'subject', + 'reply_to', + 'correlation_id', + 'content_type', + 'content_encoding', + 'absolute_expiry_time', + 'creation_time', + 'group_id', + 'group_sequence', + 'reply_to_group_id' + ], + defaults=(None,) * 13 # type: ignore + ) +Properties._code = 0x00000073 # type: ignore # pylint:disable=protected-access +Properties.__new__.__defaults__ = (None,) * len(Properties._fields) # type: ignore +Properties._definition = ( # type: ignore # pylint:disable=protected-access + FIELD("message_id", FieldDefinition.message_id, False, None, False), + FIELD("user_id", AMQPTypes.binary, False, None, False), + FIELD("to", AMQPTypes.string, False, None, False), + FIELD("subject", AMQPTypes.string, False, None, False), + FIELD("reply_to", AMQPTypes.string, False, None, False), + FIELD("correlation_id", FieldDefinition.message_id, False, None, False), + FIELD("content_type", AMQPTypes.symbol, False, None, False), + FIELD("content_encoding", AMQPTypes.symbol, False, None, False), + FIELD("absolute_expiry_time", AMQPTypes.timestamp, False, None, False), + FIELD("creation_time", AMQPTypes.timestamp, False, None, False), + FIELD("group_id", AMQPTypes.string, False, None, False), + FIELD("group_sequence", AMQPTypes.uint, False, None, False), + FIELD("reply_to_group_id", AMQPTypes.string, 
False, None, False)) +if _CAN_ADD_DOCSTRING: + Properties.__doc__ = """ + Immutable properties of the Message. + + The properties section is used for a defined set of standard properties of the message. The properties + section is part of the bare message and thus must, if retransmitted by an intermediary, remain completely + unaltered. + + :param message_id: Application Message identifier. + Message-id is an optional property which uniquely identifies a Message within the Message system. + The Message producer is usually responsible for setting the message-id in such a way that it is assured + to be globally unique. A broker MAY discard a Message as a duplicate if the value of the message-id + matches that of a previously received Message sent to the same Node. + :param bytes user_id: Creating user id. + The identity of the user responsible for producing the Message. The client sets this value, and it MAY + be authenticated by intermediaries. + :param to: The address of the Node the Message is destined for. + The to field identifies the Node that is the intended destination of the Message. On any given transfer + this may not be the Node at the receiving end of the Link. + :param str subject: The subject of the message. + A common field for summary information about the Message content and purpose. + :param reply_to: The Node to send replies to. + The address of the Node to send replies to. + :param correlation_id: Application correlation identifier. + This is a client-specific id that may be used to mark or identify Messages between clients. + :param bytes content_type: MIME content type. + The RFC-2046 MIME type for the Message's application-data section (body). As per RFC-2046 this may contain + a charset parameter defining the character encoding used: e.g. 'text/plain; charset="utf-8"'. + For clarity, the correct MIME type for a truly opaque binary section is application/octet-stream. 
+ When using an application-data section with a section code other than data, content-type, if set, SHOULD + be set to a MIME type of message/x-amqp+?, where '?' is either data, map or list. + :param bytes content_encoding: MIME content encoding. + The Content-Encoding property is used as a modifier to the content-type. When present, its value indicates + what additional content encodings have been applied to the application-data, and thus what decoding + mechanisms must be applied in order to obtain the media-type referenced by the content-type header field. + Content-Encoding is primarily used to allow a document to be compressed without losing the identity of + its underlying content type. Content Encodings are to be interpreted as per Section 3.5 of RFC 2616. + Valid Content Encodings are registered at IANA as "Hypertext Transfer Protocol (HTTP) Parameters" + (http://www.iana.org/assignments/http-parameters/httpparameters.xml). Content-Encoding MUST NOT be set when + the application-data section is other than data. Implementations MUST NOT use the identity encoding. + Instead, implementations should not set this property. Implementations SHOULD NOT use the compress + encoding, except as to remain compatible with messages originally sent with other protocols, + e.g. HTTP or SMTP. Implementations SHOULD NOT specify multiple content encoding values except as to be + compatible with messages originally sent with other protocols, e.g. HTTP or SMTP. + :param datetime absolute_expiry_time: The time when this message is considered expired. + An absolute time when this message is considered to be expired. + :param datetime creation_time: The time when this message was created. + An absolute time when this message was created. + :param str group_id: The group this message belongs to. + Identifies the group the message belongs to. + :param int group_sequence: The sequence-no of this message within its group. + The relative position of this message within its group.
+ :param str reply_to_group_id: The group the reply message belongs to. + This is a client-specific id that is used so that the client can send replies to this message to a specific group. + """ + +# TODO: should be a class, namedtuple or dataclass, immutability vs performance, need to collect performance data +Message = namedtuple( + 'Message', + [ + 'header', + 'delivery_annotations', + 'message_annotations', + 'properties', + 'application_properties', + 'data', + 'sequence', + 'value', + 'footer', + ], + defaults=(None,) * 9 # type: ignore + ) +Message._code = 0 # type: ignore # pylint:disable=protected-access +Message._definition = ( # type: ignore # pylint:disable=protected-access + (0x00000070, FIELD("header", Header, False, None, False)), + (0x00000071, FIELD("delivery_annotations", FieldDefinition.annotations, False, None, False)), + (0x00000072, FIELD("message_annotations", FieldDefinition.annotations, False, None, False)), + (0x00000073, FIELD("properties", Properties, False, None, False)), + (0x00000074, FIELD("application_properties", AMQPTypes.map, False, None, False)), + (0x00000075, FIELD("data", AMQPTypes.binary, False, None, True)), + (0x00000076, FIELD("sequence", AMQPTypes.list, False, None, False)), + (0x00000077, FIELD("value", None, False, None, False)), + (0x00000078, FIELD("footer", FieldDefinition.annotations, False, None, False))) +if _CAN_ADD_DOCSTRING: + Message.__doc__ = """ + An annotated message consists of the bare message plus sections for annotation at the head and tail + of the bare message. + + There are two classes of annotations: annotations that travel with the message indefinitely, and + annotations that are consumed by the next node. + The exact structure of a message, together with its encoding, is defined by the message format. This document + defines the structure and semantics of message format 0 (MESSAGE-FORMAT). Altogether a message consists of the + following sections: + + - Zero or one header.
+ - Zero or one delivery-annotations. + - Zero or one message-annotations. + - Zero or one properties. + - Zero or one application-properties. + - The body consists of either: one or more data sections, one or more amqp-sequence sections, + or a single amqp-value section. + - Zero or one footer. + + :param ~uamqp.message.Header header: Transport headers for a Message. + The header section carries standard delivery details about the transfer of a Message through the AMQP + network. If the header section is omitted the receiver MUST assume the appropriate default values for + the fields within the header unless other target or node specific defaults have otherwise been set. + :param dict delivery_annotations: The delivery-annotations section is used for delivery-specific non-standard + properties at the head of the message. Delivery annotations convey information from the sending peer to + the receiving peer. If the recipient does not understand the annotation it cannot be acted upon and its + effects (such as any implied propagation) cannot be acted upon. Annotations may be specific to one + implementation, or common to multiple implementations. The capabilities negotiated on link attach and on + the source and target should be used to establish which annotations a peer supports. A registry of defined + annotations and their meanings can be found here: http://www.amqp.org/specification/1.0/delivery-annotations. + If the delivery-annotations section is omitted, it is equivalent to a delivery-annotations section + containing an empty map of annotations. + :param dict message_annotations: The message-annotations section is used for properties of the message which + are aimed at the infrastructure and should be propagated across every delivery step. Message annotations + convey information about the message. Intermediaries MUST propagate the annotations unless the annotations + are explicitly augmented or modified (e.g. by the use of the modified outcome). 
+ The capabilities negotiated on link attach and on the source and target may be used to establish which + annotations a peer understands, however in a network of AMQP intermediaries it may not be possible to know + if every intermediary will understand the annotation. Note that for some annotations it may not be necessary + for the intermediary to understand their purpose - they may be used purely as an attribute which can be + filtered on. A registry of defined annotations and their meanings can be found here: + http://www.amqp.org/specification/1.0/message-annotations. If the message-annotations section is omitted, + it is equivalent to a message-annotations section containing an empty map of annotations. + :param ~uamqp.message.Properties properties: Immutable properties of the Message. + The properties section is used for a defined set of standard properties of the message. The properties + section is part of the bare message and thus must, if retransmitted by an intermediary, remain completely + unaltered. + :param dict application_properties: The application-properties section is a part of the bare message used + for structured application data. Intermediaries may use the data within this structure for the purposes + of filtering or routing. The keys of this map are restricted to be of type string (which excludes the + possibility of a null key) and the values are restricted to be of simple types only (that is excluding + map, list, and array types). + :param list(bytes) data: A data section contains opaque binary data. + :param list sequence: A sequence section contains an arbitrary number of structured data elements. + :param value: An amqp-value section contains a single AMQP value. + :param dict footer: Transport footers for a Message.
+ The footer section is used for details about the message or delivery which can only be calculated or + evaluated once the whole bare message has been constructed or seen (for example message hashes, HMACs, + signatures and encryption details). A registry of defined footers and their meanings can be found + here: http://www.amqp.org/specification/1.0/footer. + """ + + +class BatchMessage(Message): + _code = 0x80013700 + + +class _MessageDelivery: + def __init__(self, message, state=MessageDeliveryState.WaitingToBeSent, expiry=None): + self.message = message + self.state = state + self.expiry = expiry + self.reason = None + self.delivery = None + self.error = None diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/outcomes.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/outcomes.py new file mode 100644 index 0000000000000..64c5d09c7f661 --- /dev/null +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/outcomes.py @@ -0,0 +1,160 @@ +#------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +#-------------------------------------------------------------------------- + +# The Messaging layer defines a concrete set of delivery states which can be used (via the disposition frame) +# to indicate the state of the message at the receiver. + +# Delivery states may be either terminal or non-terminal. Once a delivery reaches a terminal delivery-state, +# the state for that delivery will no longer change. A terminal delivery-state is referred to as an outcome. 
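The namedtuple-plus-class-attribute pattern used for every composite type in these modules can be sketched in isolation (a minimal sketch; the simplified `FIELD` tuple below is a stand-in for the real helper imported from `.constants`):

```python
from collections import namedtuple

# Simplified stand-in for the FIELD helper defined in .constants.
FIELD = namedtuple("FIELD", ["name", "type", "mandatory", "default", "multiple"])

# The pattern: a namedtuple holds the composite-type fields, while the AMQP
# descriptor code and per-field encoding rules hang off the class itself.
Received = namedtuple("Received", ["section_number", "section_offset"])
Received._code = 0x00000023
Received._definition = (
    FIELD("section_number", "uint", True, None, False),
    FIELD("section_offset", "ulong", True, None, False),
)

state = Received(section_number=0, section_offset=0)
print(state.section_number, hex(Received._code))
```

Keeping the descriptor code and encoding rules on the class rather than on each instance means an encoder can look them up once per type while instances stay plain, cheap tuples.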
+ +# The following outcomes are formally defined by the messaging layer to indicate the result of processing at the +# receiver: + +# - accepted: indicates successful processing at the receiver +# - rejected: indicates an invalid and unprocessable message +# - released: indicates that the message was not (and will not be) processed +# - modified: indicates that the message was modified, but not processed + +# The following non-terminal delivery-state is formally defined by the messaging layer for use during link +# recovery to allow the sender to resume the transfer of a large message without retransmitting all the +# message data: + +# - received: indicates partial message data seen by the receiver as well as the starting point for a +# resumed transfer + +# TODO: fix mypy errors for _code/_definition/__defaults__ (issue #26500) +from collections import namedtuple + +from .types import AMQPTypes, FieldDefinition, ObjDefinition +from .constants import FIELD +from .performatives import _CAN_ADD_DOCSTRING + + +Received = namedtuple('Received', ['section_number', 'section_offset']) +Received._code = 0x00000023 # type: ignore # pylint:disable=protected-access +Received._definition = ( # type: ignore # pylint:disable=protected-access + FIELD("section_number", AMQPTypes.uint, True, None, False), + FIELD("section_offset", AMQPTypes.ulong, True, None, False)) +if _CAN_ADD_DOCSTRING: + Received.__doc__ = """ + At the target the received state indicates the furthest point in the payload of the message + which the target will not need to have resent if the link is resumed. At the source the received state represents + the earliest point in the payload which the Sender is able to resume transferring at in the case of link + resumption. When resuming a delivery, if this state is set on the first transfer performative it indicates + the offset in the payload at which the first resumed delivery is starting. 
The Sender MUST NOT send the + received state on transfer or disposition performatives except on the first transfer performative on a + resumed delivery. + + :param int section_number: + When sent by the Sender this indicates the first section of the message (with section-number 0 being the + first section) for which data can be resent. Data from sections prior to the given section cannot be + retransmitted for this delivery. When sent by the Receiver this indicates the first section of the message + for which all data may not yet have been received. + :param int section_offset: + When sent by the Sender this indicates the first byte of the encoded section data of the section given by + section-number for which data can be resent (with section-offset 0 being the first byte). Bytes from the + same section prior to the given offset cannot be retransmitted for this delivery. When sent by the + Receiver this indicates the first byte of the given section which has not yet been received. Note that if + a receiver has received all of section number X (which contains N bytes of data), but none of section + number X + 1, then it may indicate this by sending either Received(section-number=X, section-offset=N) or + Received(section-number=X+1, section-offset=0). The state Received(section-number=0, section-offset=0) + indicates that no message data at all has been transferred. + """ + + +Accepted = namedtuple('Accepted', []) +Accepted._code = 0x00000024 # type: ignore # pylint:disable=protected-access +Accepted._definition = () # type: ignore # pylint:disable=protected-access +if _CAN_ADD_DOCSTRING: + Accepted.__doc__ = """ + The accepted outcome. + + At the source the accepted state means that the message has been retired from the node, and transfer of + payload data will not be able to be resumed if the link becomes suspended.
A delivery may become accepted at + the source even before all transfer frames have been sent, this does not imply that the remaining transfers + for the delivery will not be sent - only the aborted flag on the transfer performative can be used to indicate + a premature termination of the transfer. At the target, the accepted outcome is used to indicate that an + incoming Message has been successfully processed, and that the receiver of the Message is expecting the sender + to transition the delivery to the accepted state at the source. The accepted outcome does not increment the + delivery-count in the header of the accepted Message. + """ + + +Rejected = namedtuple('Rejected', ['error']) +Rejected.__new__.__defaults__ = (None,) * len(Rejected._fields) # type: ignore +Rejected._code = 0x00000025 # type: ignore # pylint:disable=protected-access +Rejected._definition = (FIELD("error", ObjDefinition.error, False, None, False),) # type: ignore # pylint:disable=protected-access +if _CAN_ADD_DOCSTRING: + Rejected.__doc__ = """ + The rejected outcome. + + At the target, the rejected outcome is used to indicate that an incoming Message is invalid and therefore + unprocessable. The rejected outcome when applied to a Message will cause the delivery-count to be incremented + in the header of the rejected Message. At the source, the rejected outcome means that the target has informed + the source that the message was rejected, and the source has taken the required action. The delivery SHOULD + NOT ever spontaneously attain the rejected state at the source. + + :param ~uamqp.error.AMQPError error: The error that caused the message to be rejected. + The value supplied in this field will be placed in the delivery-annotations of the rejected Message + associated with the symbolic key "rejected".
+ """ + + +Released = namedtuple('Released', []) +Released._code = 0x00000026 # type: ignore # pylint:disable=protected-access +Released._definition = () # type: ignore # pylint:disable=protected-access +if _CAN_ADD_DOCSTRING: + Released.__doc__ = """ + The released outcome. + + At the source the released outcome means that the message is no longer acquired by the receiver, and has been + made available for (re-)delivery to the same or other targets receiving from the node. The message is unchanged + at the node (i.e. the delivery-count of the header of the released Message MUST NOT be incremented). + As released is a terminal outcome, transfer of payload data will not be able to be resumed if the link becomes + suspended. A delivery may become released at the source even before all transfer frames have been sent, this + does not imply that the remaining transfers for the delivery will not be sent. The source MAY spontaneously + attain the released outcome for a Message (for example the source may implement some sort of time bound + acquisition lock, after which the acquisition of a message at a node is revoked to allow for delivery to an + alternative consumer). + + At the target, the released outcome is used to indicate that a given transfer was not and will not be acted upon. + """ + + +Modified = namedtuple('Modified', ['delivery_failed', 'undeliverable_here', 'message_annotations']) +Modified.__new__.__defaults__ = (None,) * len(Modified._fields) # type: ignore +Modified._code = 0x00000027 # type: ignore # pylint:disable=protected-access +Modified._definition = ( # type: ignore # pylint:disable=protected-access + FIELD('delivery_failed', AMQPTypes.boolean, False, None, False), + FIELD('undeliverable_here', AMQPTypes.boolean, False, None, False), + FIELD('message_annotations', FieldDefinition.fields, False, None, False)) +if _CAN_ADD_DOCSTRING: + Modified.__doc__ = """ + The modified outcome. 
+ + At the source the modified outcome means that the message is no longer acquired by the receiver, and has been + made available for (re-)delivery to the same or other targets receiving from the node. The message has been + changed at the node in the ways indicated by the fields of the outcome. As modified is a terminal outcome, + transfer of payload data will not be able to be resumed if the link becomes suspended. A delivery may become + modified at the source even before all transfer frames have been sent, this does not imply that the remaining + transfers for the delivery will not be sent. The source MAY spontaneously attain the modified outcome for a + Message (for example the source may implement some sort of time bound acquisition lock, after which the + acquisition of a message at a node is revoked to allow for delivery to an alternative consumer with the + message modified in some way to denote the previous failed delivery attempt, e.g. with delivery-failed set to true). + At the target, the modified outcome is used to indicate that a given transfer was not and will not be acted + upon, and that the message should be modified in the specified ways at the node. + + :param bool delivery_failed: Count the transfer as an unsuccessful delivery attempt. + If the delivery-failed flag is set, any Messages modified MUST have their delivery-count incremented. + :param bool undeliverable_here: Prevent redelivery. + If the undeliverable-here is set, then any Messages released MUST NOT be redelivered to the modifying + Link Endpoint. + :param dict message_annotations: Message attributes. + Map containing attributes to combine with the existing message-annotations held in the Message's header + section.
Where the existing message-annotations of the Message contain an entry with the same key as an + entry in this field, the value in this field associated with that key replaces the one in the existing + headers; where the existing message-annotations has no such value, the value in this map is added. + """ diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/performatives.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/performatives.py new file mode 100644 index 0000000000000..efcfc444ccd7b --- /dev/null +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/performatives.py @@ -0,0 +1,634 @@ +#------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +#-------------------------------------------------------------------------- + +# TODO: fix mypy errors for _code/_definition/__defaults__ (issue #26500) +from collections import namedtuple +import sys + +from .types import AMQPTypes, FieldDefinition, ObjDefinition +from .constants import FIELD + +_CAN_ADD_DOCSTRING = sys.version_info.major >= 3 + + +OpenFrame = namedtuple( + 'OpenFrame', + [ + 'container_id', + 'hostname', + 'max_frame_size', + 'channel_max', + 'idle_timeout', + 'outgoing_locales', + 'incoming_locales', + 'offered_capabilities', + 'desired_capabilities', + 'properties' + ]) +OpenFrame._code = 0x00000010 # type: ignore # pylint:disable=protected-access +OpenFrame._definition = ( # type: ignore # pylint:disable=protected-access + FIELD("container_id", AMQPTypes.string, True, None, False), + FIELD("hostname", AMQPTypes.string, False, None, False), + FIELD("max_frame_size", AMQPTypes.uint, False, 4294967295, False), + FIELD("channel_max", AMQPTypes.ushort, False, 65535, False), + FIELD("idle_timeout", AMQPTypes.uint, False, None, False), + FIELD("outgoing_locales", AMQPTypes.symbol, False, 
None, True), + FIELD("incoming_locales", AMQPTypes.symbol, False, None, True), + FIELD("offered_capabilities", AMQPTypes.symbol, False, None, True), + FIELD("desired_capabilities", AMQPTypes.symbol, False, None, True), + FIELD("properties", FieldDefinition.fields, False, None, False)) +if _CAN_ADD_DOCSTRING: + OpenFrame.__doc__ = """ + OPEN performative. Negotiate Connection parameters. + + The first frame sent on a connection in either direction MUST contain an Open body. + (Note that the Connection header which is sent first on the Connection is *not* a frame.) + The fields indicate the capabilities and limitations of the sending peer. + + :param str container_id: The ID of the source container. + :param str hostname: The name of the target host. + The DNS name of the host (either fully qualified or relative) to which the sending peer is connecting. + It is not mandatory to provide the hostname. If no hostname is provided the receiving peer should select + a default based on its own configuration. This field can be used by AMQP proxies to determine the correct + back-end service to connect the client to. This field may already have been specified by the sasl-init frame, + if a SASL layer is used, or by the server name indication extension as described in RFC-4366, if a TLS layer + is used, in which case this field SHOULD be null or contain the same value. It is undefined what a different + value to those already specified means. + :param int max_frame_size: Proposed maximum frame size in bytes. + The largest frame size that the sending peer is able to accept on this Connection. + If this field is not set it means that the peer does not impose any specific limit. A peer MUST NOT send + frames larger than its partner can handle. A peer that receives an oversized frame MUST close the Connection + with the framing-error error-code. Both peers MUST accept frames of up to 512 (MIN-MAX-FRAME-SIZE) + octets.
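The max-frame-size rules above reduce to taking the smaller of the two peers' limits and rejecting anything below the mandatory 512-octet minimum. A hypothetical helper (illustrative only, not part of this module) might look like:

```python
MIN_MAX_FRAME_SIZE = 512  # every peer MUST accept frames of up to 512 octets

def negotiated_max_frame_size(local_max, remote_max):
    """Reconcile the local and remote max-frame-size from the OPEN exchange.

    An unset (None) value means that peer imposes no specific limit.
    Illustrative sketch only; not the client's actual negotiation code.
    """
    candidates = [v for v in (local_max, remote_max) if v is not None]
    agreed = min(candidates) if candidates else None
    if agreed is not None and agreed < MIN_MAX_FRAME_SIZE:
        raise ValueError("max-frame-size below MIN-MAX-FRAME-SIZE (512)")
    return agreed

print(negotiated_max_frame_size(65536, 4294967295))  # 65536
```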
+ :param int channel_max: The maximum channel number that may be used on the Connection. + The channel-max value is the highest channel number that may be used on the Connection. This value plus one + is the maximum number of Sessions that can be simultaneously active on the Connection. A peer MUST NOT use + channel numbers outside the range that its partner can handle. A peer that receives a channel number + outside the supported range MUST close the Connection with the framing-error error-code. + :param int idle_timeout: Idle time-out in milliseconds. + The idle time-out required by the sender. A value of zero is the same as if it was not set (null). If the + receiver is unable or unwilling to support the idle time-out then it should close the connection with + an error explaining why (e.g. because it is too small). If the value is not set, then the sender does not + have an idle time-out. However, senders doing this should be aware that implementations MAY choose to use + an internal default to efficiently manage a peer's resources. + :param list(str) outgoing_locales: Locales available for outgoing text. + A list of the locales that the peer supports for sending informational text. This includes Connection, + Session and Link error descriptions. A peer MUST support at least the en-US locale. Since this value + is always supported, it need not be supplied in the outgoing-locales. A null value or an empty list implies + that only en-US is supported. + :param list(str) incoming_locales: Desired locales for incoming text in decreasing level of preference. + A list of locales that the sending peer permits for incoming informational text. This list is ordered in + decreasing level of preference. The receiving partner will choose the first (most preferred) incoming locale + from those which it supports. If none of the requested locales are supported, en-US will be chosen. Note + that en-US need not be supplied in this list as it is always the fallback.
A peer may determine which of the + permitted incoming locales is chosen by examining the partner's supported locales as specified in the + outgoing_locales field. A null value or an empty list implies that only en-US is supported. + :param list(str) offered_capabilities: The extension capabilities the sender supports. + If the receiver of the offered-capabilities requires an extension capability which is not present in the + offered-capability list then it MUST close the connection. A list of commonly defined connection capabilities + and their meanings can be found here: http://www.amqp.org/specification/1.0/connection-capabilities. + :param list(str) desired_capabilities: The extension capabilities the sender may use if the receiver supports + them. The desired-capability list defines which extension capabilities the sender MAY use if the receiver + offers them (i.e. they are in the offered-capabilities list received by the sender of the + desired-capabilities). If the receiver of the desired-capabilities offers extension capabilities which are + not present in the desired-capability list it received, then it can be sure those (undesired) capabilities + will not be used on the Connection. + :param dict properties: Connection properties. + The properties map contains a set of fields intended to indicate information about the connection and its + container. A list of commonly defined connection properties and their meanings can be found + here: http://www.amqp.org/specification/1.0/connection-properties.
+ """ + + +BeginFrame = namedtuple( + 'BeginFrame', + [ + 'remote_channel', + 'next_outgoing_id', + 'incoming_window', + 'outgoing_window', + 'handle_max', + 'offered_capabilities', + 'desired_capabilities', + 'properties' + ]) +BeginFrame._code = 0x00000011 # type: ignore # pylint:disable=protected-access +BeginFrame._definition = ( # type: ignore # pylint:disable=protected-access + FIELD("remote_channel", AMQPTypes.ushort, False, None, False), + FIELD("next_outgoing_id", AMQPTypes.uint, True, None, False), + FIELD("incoming_window", AMQPTypes.uint, True, None, False), + FIELD("outgoing_window", AMQPTypes.uint, True, None, False), + FIELD("handle_max", AMQPTypes.uint, False, 4294967295, False), + FIELD("offered_capabilities", AMQPTypes.symbol, False, None, True), + FIELD("desired_capabilities", AMQPTypes.symbol, False, None, True), + FIELD("properties", FieldDefinition.fields, False, None, False)) +if _CAN_ADD_DOCSTRING: + BeginFrame.__doc__ = """ + BEGIN performative. Begin a Session on a channel. + + Indicate that a Session has begun on the channel. + + :param int remote_channel: The remote channel for this Session. + If a Session is locally initiated, the remote-channel MUST NOT be set. When an endpoint responds to a + remotely initiated Session, the remote-channel MUST be set to the channel on which the remote Session + sent the begin. + :param int next_outgoing_id: The transfer-id of the first transfer id the sender will send. + The next-outgoing-id is used to assign a unique transfer-id to all outgoing transfer frames on a given + session. The next-outgoing-id may be initialized to an arbitrary value and is incremented after each + successive transfer according to RFC-1982 serial number arithmetic. + :param int incoming_window: The initial incoming-window of the sender. + The incoming-window defines the maximum number of incoming transfer frames that the endpoint can currently + receive. 
This identifies a current maximum incoming transfer-id that can be computed by subtracting one + from the sum of incoming-window and next-incoming-id. + :param int outgoing_window: The initial outgoing-window of the sender. + The outgoing-window defines the maximum number of outgoing transfer frames that the endpoint can currently + send. This identifies a current maximum outgoing transfer-id that can be computed by subtracting one from + the sum of outgoing-window and next-outgoing-id. + :param int handle_max: The maximum handle value that may be used on the Session. + The handle-max value is the highest handle value that may be used on the Session. A peer MUST NOT attempt + to attach a Link using a handle value outside the range that its partner can handle. A peer that receives + a handle outside the supported range MUST close the Connection with the framing-error error-code. + :param list(str) offered_capabilities: The extension capabilities the sender supports. + A list of commonly defined session capabilities and their meanings can be found + here: http://www.amqp.org/specification/1.0/session-capabilities. + :param list(str) desired_capabilities: The extension capabilities the sender may use if the receiver + supports them. + :param dict properties: Session properties. + The properties map contains a set of fields intended to indicate information about the session and its + container. A list of commonly defined session properties and their meanings can be found + here: http://www.amqp.org/specification/1.0/session-properties. 
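The incoming-window arithmetic described above can be made concrete. This illustrative helper (not part of the module) computes the highest transfer-id currently acceptable, wrapping modulo 2**32 per RFC-1982 serial-number arithmetic:

```python
UINT_MAX = 2 ** 32  # transfer-ids are AMQP uints, RFC-1982 serial numbers

def max_incoming_transfer_id(next_incoming_id, incoming_window):
    """Highest transfer-id the endpoint can currently accept.

    Per the BEGIN semantics: one less than the sum of incoming-window and
    next-incoming-id, modulo 2**32. Illustrative sketch only.
    """
    return (next_incoming_id + incoming_window - 1) % UINT_MAX

print(max_incoming_transfer_id(0, 100))          # 99
print(max_incoming_transfer_id(4294967290, 10))  # 3 (wraps modulo 2**32)
```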
+ """ + + +AttachFrame = namedtuple( + 'AttachFrame', + [ + 'name', + 'handle', + 'role', + 'send_settle_mode', + 'rcv_settle_mode', + 'source', + 'target', + 'unsettled', + 'incomplete_unsettled', + 'initial_delivery_count', + 'max_message_size', + 'offered_capabilities', + 'desired_capabilities', + 'properties' + ]) +AttachFrame._code = 0x00000012 # type: ignore # pylint:disable=protected-access +AttachFrame._definition = ( # type: ignore # pylint:disable=protected-access + FIELD("name", AMQPTypes.string, True, None, False), + FIELD("handle", AMQPTypes.uint, True, None, False), + FIELD("role", AMQPTypes.boolean, True, None, False), + FIELD("send_settle_mode", AMQPTypes.ubyte, False, 2, False), + FIELD("rcv_settle_mode", AMQPTypes.ubyte, False, 0, False), + FIELD("source", ObjDefinition.source, False, None, False), + FIELD("target", ObjDefinition.target, False, None, False), + FIELD("unsettled", AMQPTypes.map, False, None, False), + FIELD("incomplete_unsettled", AMQPTypes.boolean, False, False, False), + FIELD("initial_delivery_count", AMQPTypes.uint, False, None, False), + FIELD("max_message_size", AMQPTypes.ulong, False, None, False), + FIELD("offered_capabilities", AMQPTypes.symbol, False, None, True), + FIELD("desired_capabilities", AMQPTypes.symbol, False, None, True), + FIELD("properties", FieldDefinition.fields, False, None, False)) +if _CAN_ADD_DOCSTRING: + AttachFrame.__doc__ = """ + ATTACH performative. Attach a Link to a Session. + + The attach frame indicates that a Link Endpoint has been attached to the Session. The opening flag + is used to indicate that the Link Endpoint is newly created. + + :param str name: The name of the link. + This name uniquely identifies the link from the container of the source to the container of the target + node, e.g. if the container of the source node is A, and the container of the target node is B, the link + may be globally identified by the (ordered) tuple(A,B,). + :param int handle: The handle of the link. 
+ The handle MUST NOT be used for other open Links. An attempt to attach using a handle which is already + associated with a Link MUST be responded to with an immediate close carrying a Handle-in-use session-error. + To make it easier to monitor AMQP link attach frames, it is recommended that implementations always assign + the lowest available handle to this field. + :param bool role: The role of the link endpoint. Either Role.Sender (False) or Role.Receiver (True). + :param str send_settle_mode: The settlement mode for the Sender. + Determines the settlement policy for deliveries sent at the Sender. When set at the Receiver this indicates + the desired value for the settlement mode at the Sender. When set at the Sender this indicates the actual + settlement mode in use. + :param str rcv_settle_mode: The settlement mode of the Receiver. + Determines the settlement policy for unsettled deliveries received at the Receiver. When set at the Sender + this indicates the desired value for the settlement mode at the Receiver. When set at the Receiver this + indicates the actual settlement mode in use. + :param ~uamqp.messaging.Source source: The source for Messages. + If no source is specified on an outgoing Link, then there is no source currently attached to the Link. + A Link with no source will never produce outgoing Messages. + :param ~uamqp.messaging.Target target: The target for Messages. + If no target is specified on an incoming Link, then there is no target currently attached to the Link. + A Link with no target will never permit incoming Messages. + :param dict unsettled: Unsettled delivery state. + This is used to indicate any unsettled delivery states when a suspended link is resumed. The map is keyed + by delivery-tag with values indicating the delivery state. The local and remote delivery states for a given + delivery-tag MUST be compared to resolve any in-doubt deliveries.
If necessary, deliveries MAY be resent, + or resumed based on the outcome of this comparison. If the local unsettled map is too large to be encoded + within a frame of the agreed maximum frame size then the session may be ended with the + frame-size-too-small error. The endpoint SHOULD make use of the ability to send an incomplete unsettled map + to avoid sending an error. The unsettled map MUST NOT contain null valued keys. When reattaching + (as opposed to resuming), the unsettled map MUST be null. + :param bool incomplete_unsettled: + If set to true this field indicates that the unsettled map provided is not complete. When the map is + incomplete the recipient of the map cannot take the absence of a delivery tag from the map as evidence of + settlement. On receipt of an incomplete unsettled map a sending endpoint MUST NOT send any new deliveries + (i.e. deliveries where resume is not set to true) to its partner (and a receiving endpoint which sent an + incomplete unsettled map MUST detach with an error on receiving a transfer which does not have the resume + flag set to true). + :param int initial_delivery_count: This MUST NOT be null if role is sender, + and it is ignored if the role is receiver. + :param int max_message_size: The maximum message size supported by the link endpoint. + This field indicates the maximum message size supported by the link endpoint. Any attempt to deliver a + message larger than this results in a message-size-exceeded link-error. If this field is zero or unset, + there is no maximum size imposed by the link endpoint. + :param list(str) offered_capabilities: The extension capabilities the sender supports. + A list of commonly defined session capabilities and their meanings can be found + here: http://www.amqp.org/specification/1.0/link-capabilities. + :param list(str) desired_capabilities: The extension capabilities the sender may use if the receiver + supports them. + :param dict properties: Link properties.
+ The properties map contains a set of fields intended to indicate information about the link and its + container. A list of commonly defined link properties and their meanings can be found + here: http://www.amqp.org/specification/1.0/link-properties. + """ + + +FlowFrame = namedtuple( + 'FlowFrame', + [ + 'next_incoming_id', + 'incoming_window', + 'next_outgoing_id', + 'outgoing_window', + 'handle', + 'delivery_count', + 'link_credit', + 'available', + 'drain', + 'echo', + 'properties' + ]) +FlowFrame.__new__.__defaults__ = (None, None, None, None, None, None, None) # type: ignore +FlowFrame._code = 0x00000013 # type: ignore # pylint:disable=protected-access +FlowFrame._definition = ( # type: ignore # pylint:disable=protected-access + FIELD("next_incoming_id", AMQPTypes.uint, False, None, False), + FIELD("incoming_window", AMQPTypes.uint, True, None, False), + FIELD("next_outgoing_id", AMQPTypes.uint, True, None, False), + FIELD("outgoing_window", AMQPTypes.uint, True, None, False), + FIELD("handle", AMQPTypes.uint, False, None, False), + FIELD("delivery_count", AMQPTypes.uint, False, None, False), + FIELD("link_credit", AMQPTypes.uint, False, None, False), + FIELD("available", AMQPTypes.uint, False, None, False), + FIELD("drain", AMQPTypes.boolean, False, False, False), + FIELD("echo", AMQPTypes.boolean, False, False, False), + FIELD("properties", FieldDefinition.fields, False, None, False)) +if _CAN_ADD_DOCSTRING: + FlowFrame.__doc__ = """ + FLOW performative. Update link state. + + Updates the flow state for the specified Link. + + :param int next_incoming_id: Identifies the expected transfer-id of the next incoming transfer frame. + This value is not set if and only if the sender has not yet received the begin frame for the session. + :param int incoming_window: Defines the maximum number of incoming transfer frames that the endpoint can + concurrently receive.
+ :param int next_outgoing_id: The transfer-id that will be assigned to the next outgoing transfer frame. + :param int outgoing_window: Defines the maximum number of outgoing transfer frames that the endpoint could + potentially currently send, if it was not constrained by restrictions imposed by its peer's incoming-window. + :param int handle: If set, indicates that the flow frame carries flow state information for the local Link + Endpoint associated with the given handle. If not set, the flow frame is carrying only information + pertaining to the Session Endpoint. If set to a handle that is not currently associated with an attached + Link, the recipient MUST respond by ending the session with an unattached-handle session error. + :param int delivery_count: The endpoint's delivery-count. + When the handle field is not set, this field MUST NOT be set. When the handle identifies that the flow + state is being sent from the Sender Link Endpoint to Receiver Link Endpoint this field MUST be set to the + current delivery-count of the Link Endpoint. When the flow state is being sent from the Receiver Endpoint + to the Sender Endpoint this field MUST be set to the last known value of the corresponding Sending Endpoint. + In the event that the Receiving Link Endpoint has not yet seen the initial attach frame from the Sender + this field MUST NOT be set. + :param int link_credit: The current maximum number of Messages that can be received. + The current maximum number of Messages that can be handled at the Receiver Endpoint of the Link. Only the + receiver endpoint can independently set this value. The sender endpoint sets this to the last known + value seen from the receiver. When the handle field is not set, this field MUST NOT be set. + :param int available: The number of available Messages. + The number of Messages awaiting credit at the link sender endpoint. Only the sender can independently set + this value. 
The receiver sets this to the last known value seen from the sender. When the handle field is + not set, this field MUST NOT be set. + :param bool drain: Indicates drain mode. + When flow state is sent from the sender to the receiver, this field contains the actual drain mode of the + sender. When flow state is sent from the receiver to the sender, this field contains the desired drain + mode of the receiver. When the handle field is not set, this field MUST NOT be set. + :param bool echo: Request link state from other endpoint. + :param dict properties: Link state properties. + A list of commonly defined link state properties and their meanings can be found + here: http://www.amqp.org/specification/1.0/link-state-properties. + """ + + +TransferFrame = namedtuple( + 'TransferFrame', + [ + 'handle', + 'delivery_id', + 'delivery_tag', + 'message_format', + 'settled', + 'more', + 'rcv_settle_mode', + 'state', + 'resume', + 'aborted', + 'batchable', + 'payload' + ]) +TransferFrame._code = 0x00000014 # type: ignore # pylint:disable=protected-access +TransferFrame._definition = ( # type: ignore # pylint:disable=protected-access + FIELD("handle", AMQPTypes.uint, True, None, False), + FIELD("delivery_id", AMQPTypes.uint, False, None, False), + FIELD("delivery_tag", AMQPTypes.binary, False, None, False), + FIELD("message_format", AMQPTypes.uint, False, 0, False), + FIELD("settled", AMQPTypes.boolean, False, None, False), + FIELD("more", AMQPTypes.boolean, False, False, False), + FIELD("rcv_settle_mode", AMQPTypes.ubyte, False, None, False), + FIELD("state", ObjDefinition.delivery_state, False, None, False), + FIELD("resume", AMQPTypes.boolean, False, False, False), + FIELD("aborted", AMQPTypes.boolean, False, False, False), + FIELD("batchable", AMQPTypes.boolean, False, False, False), + None) +if _CAN_ADD_DOCSTRING: + TransferFrame.__doc__ = """ + TRANSFER performative. Transfer a Message. + + The transfer frame is used to send Messages across a Link. 
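As a quick illustration of the FlowFrame definition above: only the four session-level fields are required, while the seven link-level fields default to None. The namedtuple is reconstructed stand-alone here purely for demonstration:

```python
from collections import namedtuple

# Stand-alone reconstruction of the FlowFrame shape defined above.
FlowFrame = namedtuple(
    "FlowFrame",
    ["next_incoming_id", "incoming_window", "next_outgoing_id",
     "outgoing_window", "handle", "delivery_count", "link_credit",
     "available", "drain", "echo", "properties"])
# Defaults apply to the last seven (link-level) fields only.
FlowFrame.__new__.__defaults__ = (None,) * 7

# A link-level flow granting 300 credits; unset link fields stay None.
flow = FlowFrame(next_incoming_id=0, incoming_window=65536,
                 next_outgoing_id=0, outgoing_window=65536,
                 handle=0, link_credit=300)
print(flow.link_credit, flow.drain)  # 300 None
```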
Messages may be carried by a single transfer up + to the maximum negotiated frame size for the Connection. Larger Messages may be split across several + transfer frames. + + :param int handle: Specifies the Link on which the Message is transferred. + :param int delivery_id: Alias for delivery-tag. + The delivery-id MUST be supplied on the first transfer of a multi-transfer delivery. On continuation + transfers the delivery-id MAY be omitted. It is an error if the delivery-id on a continuation transfer + differs from the delivery-id on the first transfer of a delivery. + :param bytes delivery_tag: Uniquely identifies the delivery attempt for a given Message on this Link. + This field MUST be specified for the first transfer of a multi transfer message and may only be + omitted for continuation transfers. + :param int message_format: Indicates the message format. + This field MUST be specified for the first transfer of a multi transfer message and may only be omitted + for continuation transfers. + :param bool settled: If not set on the first (or only) transfer for a delivery, then the settled flag MUST + be interpreted as being false. For subsequent transfers if the settled flag is left unset then it MUST be + interpreted as true if and only if the value of the settled flag on any of the preceding transfers was + true; if no preceding transfer was sent with settled being true then the value when unset MUST be taken + as false. If the negotiated value for snd-settle-mode at attachment is settled, then this field MUST be + true on at least one transfer frame for a delivery (i.e. the delivery must be settled at the Sender at + the point the delivery has been completely transferred). If the negotiated value for snd-settle-mode at + attachment is unsettled, then this field MUST be false (or unset) on every transfer frame for a delivery + (unless the delivery is aborted). + :param bool more: Indicates that the Message has more content. 
+ Note that if both the more and aborted fields are set to true, the aborted flag takes precedence. That is, + a receiver should ignore the value of the more field if the transfer is marked as aborted. A sender + SHOULD NOT set the more flag to true if it also sets the aborted flag to true. + :param str rcv_settle_mode: If first, this indicates that the Receiver MUST settle the delivery once it has + arrived without waiting for the Sender to settle first. If second, this indicates that the Receiver MUST + NOT settle until sending its disposition to the Sender and receiving a settled disposition from the sender. + If not set, this value is defaulted to the value negotiated on link attach. If the negotiated link value is + first, then it is illegal to set this field to second. If the message is being sent settled by the Sender, + the value of this field is ignored. The (implicit or explicit) value of this field does not form part of the + transfer state, and is not retained if a link is suspended and subsequently resumed. + :param bytes state: The state of the delivery at the sender. + When set this informs the receiver of the state of the delivery at the sender. This is particularly useful + when transfers of unsettled deliveries are resumed after resuming a link. Setting the state on the + transfer can be thought of as being equivalent to sending a disposition immediately before the transfer + performative, i.e. it is the state of the delivery (not the transfer) that existed at the point the frame + was sent. Note that if the transfer performative (or an earlier disposition performative referring to the + delivery) indicates that the delivery has attained a terminal state, then no future transfer or disposition + sent by the sender can alter that terminal state. + :param bool resume: Indicates a resumed delivery. + If true, the resume flag indicates that the transfer is being used to reassociate an unsettled delivery + from a dissociated link endpoint.
The receiver MUST ignore resumed deliveries that are not in its local + unsettled map. The sender MUST NOT send resumed transfers for deliveries not in its local unsettled map. + If a resumed delivery spans more than one transfer performative, then the resume flag MUST be set to true + on the first transfer of the resumed delivery. For subsequent transfers for the same delivery the resume + flag may be set to true, or may be omitted. In the case where the exchange of unsettled maps makes clear + that all message data has been successfully transferred to the receiver, and that only the final state + (and potentially settlement) at the sender needs to be conveyed, then a resumed delivery may carry no + payload and instead act solely as a vehicle for carrying the terminal state of the delivery at the sender. + :param bool aborted: Indicates that the Message is aborted. + Aborted Messages should be discarded by the recipient (any payload within the frame carrying the performative + MUST be ignored). An aborted Message is implicitly settled. + :param bool batchable: Batchable hint. + If true, then the issuer is hinting that there is no need for the peer to urgently communicate updated + delivery state. This hint may be used to artificially increase the amount of batching an implementation + uses when communicating delivery states, and thereby save bandwidth. If the message being delivered is too + large to fit within a single frame, then the setting of batchable to true on any of the transfer + performatives for the delivery is equivalent to setting batchable to true for all the transfer performatives + for the delivery. The batchable value does not form part of the transfer state, and is not retained if a + link is suspended and subsequently resumed.
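As an illustrative aside (not part of the patch itself), the multi-transfer rule described above — accumulate payload while `more` is true, decode the message once a frame arrives with `more` false — can be sketched as follows. `Transfer` here is a hypothetical stand-in for the frame tuple, not the library's API:

```python
from typing import List, NamedTuple

class Transfer(NamedTuple):
    """Hypothetical stand-in for a TRANSFER frame's payload/more fields."""
    payload: bytes
    more: bool  # True while further frames carry the rest of this message

def reassemble(frames: List[Transfer]) -> List[bytes]:
    """Accumulate payload across frames; a frame with more=False ends a message."""
    messages: List[bytes] = []
    buffer = bytearray()
    for frame in frames:
        buffer.extend(frame.payload)
        if not frame.more:
            messages.append(bytes(buffer))
            buffer = bytearray()
    return messages

# Two frames carrying one message, followed by a single-frame message:
frames = [Transfer(b"hello ", True), Transfer(b"world", False), Transfer(b"bye", False)]
assert reassemble(frames) == [b"hello world", b"bye"]
```

This mirrors the shape of `ReceiverLink._incoming_transfer` later in this patch, which extends `self._received_payload` until `frame[5]` (more) is unset.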
+ """ + + +DispositionFrame = namedtuple( + 'DispositionFrame', + [ + 'role', + 'first', + 'last', + 'settled', + 'state', + 'batchable' + ]) +DispositionFrame._code = 0x00000015 # type: ignore # pylint:disable=protected-access +DispositionFrame._definition = ( # type: ignore # pylint:disable=protected-access + FIELD("role", AMQPTypes.boolean, True, None, False), + FIELD("first", AMQPTypes.uint, True, None, False), + FIELD("last", AMQPTypes.uint, False, None, False), + FIELD("settled", AMQPTypes.boolean, False, False, False), + FIELD("state", ObjDefinition.delivery_state, False, None, False), + FIELD("batchable", AMQPTypes.boolean, False, False, False)) +if _CAN_ADD_DOCSTRING: + DispositionFrame.__doc__ = """ + DISPOSITION performative. Inform remote peer of delivery state changes. + + The disposition frame is used to inform the remote peer of local changes in the state of deliveries. + The disposition frame may reference deliveries from many different links associated with a session, + although all links MUST have the directionality indicated by the specified role. Note that it is possible + for a disposition sent from sender to receiver to refer to a delivery which has not yet completed + (i.e. a delivery which is spread over multiple frames and not all frames have yet been sent). The use of such + interleaving is discouraged in favor of carrying the modified state on the next transfer performative for + the delivery. The disposition performative may refer to deliveries on links that are no longer attached. + As long as the links have not been closed or detached with an error then the deliveries are still "live" and + the updated state MUST be applied. + + :param str role: Directionality of disposition. + The role identifies whether the disposition frame contains information about sending link endpoints + or receiving link endpoints. + :param int first: Lower bound of deliveries. + Identifies the lower bound of delivery-ids for the deliveries in this set. 
+ :param int last: Upper bound of deliveries. + Identifies the upper bound of delivery-ids for the deliveries in this set. If not set, + this is taken to be the same as first. + :param bool settled: Indicates deliveries are settled. + If true, indicates that the referenced deliveries are considered settled by the issuing endpoint. + :param bytes state: Indicates state of deliveries. + Communicates the state of all the deliveries referenced by this disposition. + :param bool batchable: Batchable hint. + If true, then the issuer is hinting that there is no need for the peer to urgently communicate the impact + of the updated delivery states. This hint may be used to artificially increase the amount of batching an + implementation uses when communicating delivery states, and thereby save bandwidth. + """ + +DetachFrame = namedtuple('DetachFrame', ['handle', 'closed', 'error']) +DetachFrame._code = 0x00000016 # type: ignore # pylint:disable=protected-access +DetachFrame._definition = ( # type: ignore # pylint:disable=protected-access + FIELD("handle", AMQPTypes.uint, True, None, False), + FIELD("closed", AMQPTypes.boolean, False, False, False), + FIELD("error", ObjDefinition.error, False, None, False)) +if _CAN_ADD_DOCSTRING: + DetachFrame.__doc__ = """ + DETACH performative. Detach the Link Endpoint from the Session. + + Detach the Link Endpoint from the Session. This un-maps the handle and makes it available for + use by other Links. + + :param int handle: The local handle of the link to be detached. + :param bool closed: If true then the sender has closed the link. + :param ~uamqp.error.AMQPError error: Error causing the detach. + If set, this field indicates that the Link is being detached due to an error condition. + The value of the field should contain details on the cause of the error.
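As a quick illustrative sketch (not library API): the DISPOSITION first/last range rule above — `last`, when unset, is taken to be the same as `first` — determines which delivery-ids a single frame settles:

```python
from typing import List, Optional

def disposition_range(first: int, last: Optional[int] = None) -> List[int]:
    """Delivery-ids covered by a DISPOSITION frame.

    Per the DISPOSITION field definitions, when `last` is not set it is
    taken to be the same as `first`; the range is inclusive at both ends.
    """
    upper = first if last is None else last
    return list(range(first, upper + 1))

assert disposition_range(7) == [7]              # last unset: just `first`
assert disposition_range(3, 6) == [3, 4, 5, 6]  # inclusive range
```

The same inclusive-range computation appears later in this patch in `SenderLink._incoming_disposition`, which builds `settled_ids` from `frame[1]` (first) and `frame[2]` (last).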
+ """ + + +EndFrame = namedtuple('EndFrame', ['error']) +EndFrame._code = 0x00000017 # type: ignore # pylint:disable=protected-access +EndFrame._definition = (FIELD("error", ObjDefinition.error, False, None, False),) # type: ignore # pylint:disable=protected-access +if _CAN_ADD_DOCSTRING: + EndFrame.__doc__ = """ + END performative. End the Session. + + Indicates that the Session has ended. + + :param ~uamqp.error.AMQPError error: Error causing the end. + If set, this field indicates that the Session is being ended due to an error condition. + The value of the field should contain details on the cause of the error. + """ + + +CloseFrame = namedtuple('CloseFrame', ['error']) +CloseFrame._code = 0x00000018 # type: ignore # pylint:disable=protected-access +CloseFrame._definition = (FIELD("error", ObjDefinition.error, False, None, False),) # type: ignore # pylint:disable=protected-access +if _CAN_ADD_DOCSTRING: + CloseFrame.__doc__ = """ + CLOSE performative. Signal a Connection close. + + Sending a close signals that the sender will not be sending any more frames (or bytes of any other kind) on + the Connection. Orderly shutdown requires that this frame MUST be written by the sender. It is illegal to + send any more frames (or bytes of any other kind) after sending a close frame. + + :param ~uamqp.error.AMQPError error: Error causing the close. + If set, this field indicates that the Connection is being closed due to an error condition. + The value of the field should contain details on the cause of the error. + """ + + +SASLMechanism = namedtuple('SASLMechanism', ['sasl_server_mechanisms']) +SASLMechanism._code = 0x00000040 # type: ignore # pylint:disable=protected-access +SASLMechanism._definition = (FIELD('sasl_server_mechanisms', AMQPTypes.symbol, True, None, True),) # type: ignore # pylint:disable=protected-access +if _CAN_ADD_DOCSTRING: + SASLMechanism.__doc__ = """ + Advertise available sasl mechanisms. 
+ + Advertises the available SASL mechanisms that may be used for authentication. + + :param list(bytes) sasl_server_mechanisms: Supported sasl mechanisms. + A list of the sasl security mechanisms supported by the sending peer. + It is invalid for this list to be null or empty. If the sending peer does not require its partner to + authenticate with it, then it should send a list of one element with its value as the SASL mechanism + ANONYMOUS. The server mechanisms are ordered in decreasing level of preference. + """ + + +SASLInit = namedtuple('SASLInit', ['mechanism', 'initial_response', 'hostname']) +SASLInit._code = 0x00000041 # type: ignore # pylint:disable=protected-access +SASLInit._definition = ( # type: ignore # pylint:disable=protected-access + FIELD('mechanism', AMQPTypes.symbol, True, None, False), + FIELD('initial_response', AMQPTypes.binary, False, None, False), + FIELD('hostname', AMQPTypes.string, False, None, False)) +if _CAN_ADD_DOCSTRING: + SASLInit.__doc__ = """ + Initiate sasl exchange. + + Selects the sasl mechanism and provides the initial response if needed. + + :param bytes mechanism: Selected security mechanism. + The name of the SASL mechanism used for the SASL exchange. If the selected mechanism is not supported by + the receiving peer, it MUST close the Connection with the authentication-failure close-code. Each peer + MUST authenticate using the highest-level security profile it can handle from the list provided by the + partner. + :param bytes initial_response: Security response data. + A block of opaque data passed to the security mechanism. The contents of this data are defined by the + SASL security mechanism. + :param str hostname: The name of the target host. + The DNS name of the host (either fully qualified or relative) to which the sending peer is connecting. It + is not mandatory to provide the hostname. If no hostname is provided the receiving peer should select a + default based on its own configuration.
This field can be used by AMQP proxies to determine the correct + back-end service to connect the client to, and to determine the domain to validate the client's credentials + against. This field may already have been specified by the server name indication extension as described + in RFC-4366, if a TLS layer is used, in which case this field SHOULD be null or contain the same value. + It is undefined what a different value to those already specified means. + """ + + +SASLChallenge = namedtuple('SASLChallenge', ['challenge']) +SASLChallenge._code = 0x00000042 # type: ignore # pylint:disable=protected-access +SASLChallenge._definition = (FIELD('challenge', AMQPTypes.binary, True, None, False),) # type: ignore # pylint:disable=protected-access +if _CAN_ADD_DOCSTRING: + SASLChallenge.__doc__ = """ + Security mechanism challenge. + + Send the SASL challenge data as defined by the SASL specification. + + :param bytes challenge: Security challenge data. + Challenge information, a block of opaque binary data passed to the security mechanism. + """ + + +SASLResponse = namedtuple('SASLResponse', ['response']) +SASLResponse._code = 0x00000043 # type: ignore # pylint:disable=protected-access +SASLResponse._definition = (FIELD('response', AMQPTypes.binary, True, None, False),) # type: ignore # pylint:disable=protected-access +if _CAN_ADD_DOCSTRING: + SASLResponse.__doc__ = """ + Security mechanism response. + + Send the SASL response data as defined by the SASL specification. + + :param bytes response: Security response data. + """ + + +SASLOutcome = namedtuple('SASLOutcome', ['code', 'additional_data']) +SASLOutcome._code = 0x00000044 # type: ignore # pylint:disable=protected-access +SASLOutcome._definition = ( # type: ignore # pylint:disable=protected-access + FIELD('code', AMQPTypes.ubyte, True, None, False), + FIELD('additional_data', AMQPTypes.binary, False, None, False)) +if _CAN_ADD_DOCSTRING: + SASLOutcome.__doc__ = """ + Indicates the outcome of the sasl dialog.
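A side note on the SASL-INIT initial-response field described above: for the PLAIN mechanism (implemented later in this patch as `SASLPlainCredential`), the response is simply three NUL-separated fields per RFC 4616 — `[authzid] NUL authcid NUL passwd`. An illustrative sketch, with hypothetical names:

```python
def sasl_plain_initial_response(authcid: str, passwd: str, authzid: str = "") -> bytes:
    """Build the RFC 4616 PLAIN message: [authzid] NUL authcid NUL passwd."""
    return b"\0".join(part.encode("utf-8") for part in (authzid, authcid, passwd))

# With no authorization identity, the message starts with a NUL byte:
assert sasl_plain_initial_response("user", "secret") == b"\0user\0secret"
```

This matches the byte layout produced by `SASLPlainCredential.start()` in sasl.py below.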
+ + This frame indicates the outcome of the SASL dialog. Upon successful completion of the SASL dialog the + Security Layer has been established, and the peers must exchange protocol headers to either start a nested + Security Layer, or to establish the AMQP Connection. + + :param int code: Indicates the outcome of the sasl dialog. + A reply-code indicating the outcome of the SASL dialog. + :param bytes additional_data: Additional data as specified in RFC-4422. + The additional-data field carries additional data on successful authentication outcome as specified by + the SASL specification (RFC-4422). If the authentication is unsuccessful, this field is not set. + """ diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/receiver.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/receiver.py new file mode 100644 index 0000000000000..5713f51b4b8cf --- /dev/null +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/receiver.py @@ -0,0 +1,121 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information.
+# -------------------------------------------------------------------------- + +import uuid +import logging +from typing import Optional, Union + +from ._decode import decode_payload +from .link import Link +from .constants import LinkState, Role +from .performatives import TransferFrame, DispositionFrame +from .outcomes import Received, Accepted, Rejected, Released, Modified + + +_LOGGER = logging.getLogger(__name__) + + +class ReceiverLink(Link): + def __init__(self, session, handle, source_address, **kwargs): + name = kwargs.pop("name", None) or str(uuid.uuid4()) + role = Role.Receiver + if "target_address" not in kwargs: + kwargs["target_address"] = "receiver-link-{}".format(name) + super(ReceiverLink, self).__init__(session, handle, name, role, source_address=source_address, **kwargs) + self._on_transfer = kwargs.pop("on_transfer") + self._received_payload = bytearray() + + @classmethod + def from_incoming_frame(cls, session, handle, frame): + # TODO: Assuming we establish all links for now... + # check link_create_from_endpoint in C lib + raise NotImplementedError("Pending") + + def _process_incoming_message(self, frame, message): + try: + return self._on_transfer(frame, message) + except Exception as e: # pylint: disable=broad-except + _LOGGER.error("Transfer callback function failed with error: %r", e, extra=self.network_trace_params) + return None + + def _incoming_attach(self, frame): + super(ReceiverLink, self)._incoming_attach(frame) + if frame[9] is None: # initial_delivery_count + _LOGGER.info("Cannot get initial-delivery-count. Detaching link", extra=self.network_trace_params) + self._set_state(LinkState.DETACHED) # TODO: Send detach now? 
+ self.delivery_count = frame[9] + self.current_link_credit = self.link_credit + self._outgoing_flow() + + def _incoming_transfer(self, frame): + if self.network_trace: + _LOGGER.debug("<- %r", TransferFrame(payload=b"***", *frame[:-1]), extra=self.network_trace_params) + self.current_link_credit -= 1 + self.delivery_count += 1 + self.received_delivery_id = frame[1] # delivery_id + if not self.received_delivery_id and not self._received_payload: + pass # TODO: delivery error + if self._received_payload or frame[5]: # more + self._received_payload.extend(frame[11]) + if not frame[5]: + if self._received_payload: + message = decode_payload(memoryview(self._received_payload)) + self._received_payload = bytearray() + else: + message = decode_payload(frame[11]) + delivery_state = self._process_incoming_message(frame, message) + if not frame[4] and delivery_state: # settled + self._outgoing_disposition( + first=frame[1], + last=frame[1], + settled=True, + state=delivery_state, + batchable=None + ) + + def _wait_for_response(self, wait: Union[bool, float]) -> None: + if wait is True: + self._session._connection.listen(wait=False) # pylint: disable=protected-access + if self.state == LinkState.ERROR: + raise self._error + elif wait: + self._session._connection.listen(wait=wait) # pylint: disable=protected-access + if self.state == LinkState.ERROR: + raise self._error + + def _outgoing_disposition( + self, + first: int, + last: Optional[int], + settled: Optional[bool], + state: Optional[Union[Received, Accepted, Rejected, Released, Modified]], + batchable: Optional[bool], + ): + disposition_frame = DispositionFrame( + role=self.role, first=first, last=last, settled=settled, state=state, batchable=batchable + ) + if self.network_trace: + _LOGGER.debug("-> %r", DispositionFrame(*disposition_frame), extra=self.network_trace_params) + self._session._outgoing_disposition(disposition_frame) # pylint: disable=protected-access + + def attach(self): + super().attach() + 
self._received_payload = bytearray() + + def send_disposition( + self, + *, + wait: Union[bool, float] = False, + first_delivery_id: int, + last_delivery_id: Optional[int] = None, + settled: Optional[bool] = None, + delivery_state: Optional[Union[Received, Accepted, Rejected, Released, Modified]] = None, + batchable: Optional[bool] = None + ): + if self._is_closed: + raise ValueError("Link already closed.") + self._outgoing_disposition(first_delivery_id, last_delivery_id, settled, delivery_state, batchable) + self._wait_for_response(wait) diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/sasl.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/sasl.py new file mode 100644 index 0000000000000..c4ff9d265540b --- /dev/null +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/sasl.py @@ -0,0 +1,146 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +from ._transport import SSLTransport, WebSocketTransport, AMQPS_PORT +from .constants import SASLCode, SASL_HEADER_FRAME, WEBSOCKET_PORT +from .performatives import SASLInit + + +_SASL_FRAME_TYPE = b"\x01" + + +class SASLPlainCredential(object): + """PLAIN SASL authentication mechanism. 
+ See https://tools.ietf.org/html/rfc4616 for details + """ + + mechanism = b"PLAIN" + + def __init__(self, authcid, passwd, authzid=None): + self.authcid = authcid + self.passwd = passwd + self.authzid = authzid + + def start(self): + if self.authzid: + login_response = self.authzid.encode("utf-8") + else: + login_response = b"" + login_response += b"\0" + login_response += self.authcid.encode("utf-8") + login_response += b"\0" + login_response += self.passwd.encode("utf-8") + return login_response + + +class SASLAnonymousCredential(object): + """ANONYMOUS SASL authentication mechanism. + See https://tools.ietf.org/html/rfc4505 for details + """ + + mechanism = b"ANONYMOUS" + + def start(self): # pylint: disable=no-self-use + return b"" + + +class SASLExternalCredential(object): + """EXTERNAL SASL mechanism. + Enables external authentication, i.e. not handled through this protocol. + Only passes 'EXTERNAL' as authentication mechanism, but no further + authentication data. + """ + + mechanism = b"EXTERNAL" + + def start(self): # pylint: disable=no-self-use + return b"" + + +class SASLTransportMixin: + def _negotiate(self): + self.write(SASL_HEADER_FRAME) + _, returned_header = self.receive_frame() + if returned_header[1] != SASL_HEADER_FRAME: + raise ValueError( + f"""Mismatching AMQP header protocol. 
Expected: {SASL_HEADER_FRAME!r}, """ + f"""received: {returned_header[1]!r}""" + ) + + _, supported_mechanisms = self.receive_frame(verify_frame_type=1) + if ( + self.credential.mechanism not in supported_mechanisms[1][0] + ): # sasl_server_mechanisms + raise ValueError( + "Unsupported SASL credential type: {}".format(self.credential.mechanism) + ) + sasl_init = SASLInit( + mechanism=self.credential.mechanism, + initial_response=self.credential.start(), + hostname=self.host, + ) + self.send_frame(0, sasl_init, frame_type=_SASL_FRAME_TYPE) + + _, next_frame = self.receive_frame(verify_frame_type=1) + frame_type, fields = next_frame + if frame_type != 0x00000044: # SASLOutcome + raise NotImplementedError("Unsupported SASL challenge") + if fields[0] == SASLCode.Ok: # code + return + raise ValueError( + "SASL negotiation failed.\nOutcome: {}\nDetails: {}".format(*fields) + ) + + +class SASLTransport(SSLTransport, SASLTransportMixin): + def __init__( + self, + host, + credential, + *, + port=AMQPS_PORT, + connect_timeout=None, + ssl_opts=None, + **kwargs, + ): + self.credential = credential + ssl_opts = ssl_opts or True + super(SASLTransport, self).__init__( + host, + port=port, + connect_timeout=connect_timeout, + ssl_opts=ssl_opts, + **kwargs, + ) + + def negotiate(self): + with self.block(): + self._negotiate() + + +class SASLWithWebSocket(WebSocketTransport, SASLTransportMixin): + def __init__( + self, + host, + credential, + *, + port=WEBSOCKET_PORT, + connect_timeout=None, + ssl_opts=None, + **kwargs, + ): + self.credential = credential + ssl_opts = ssl_opts or True + super().__init__( + host, + port=port, + connect_timeout=connect_timeout, + ssl_opts=ssl_opts, + **kwargs, + ) + + def negotiate(self): + self._negotiate() diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/sender.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/sender.py new file mode 100644 index 0000000000000..26c78f5f9c170 --- /dev/null +++
b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/sender.py @@ -0,0 +1,200 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +import struct +import uuid +import logging +import time + +from ._encode import encode_payload +from .link import Link +from .constants import SessionTransferState, LinkDeliverySettleReason, LinkState, Role, SenderSettleMode, SessionState +from .error import AMQPLinkError, ErrorCondition, MessageException + +_LOGGER = logging.getLogger(__name__) + + +class PendingDelivery(object): + def __init__(self, **kwargs): + self.message = kwargs.get("message") + self.sent = False + self.frame = None + self.on_delivery_settled = kwargs.get("on_delivery_settled") + self.start = time.time() + self.transfer_state = None + self.timeout = kwargs.get("timeout") + self.settled = kwargs.get("settled", False) + self._network_trace_params = kwargs.get('network_trace_params') + + def on_settled(self, reason, state): + if self.on_delivery_settled and not self.settled: + try: + self.on_delivery_settled(reason, state) + except Exception as e: # pylint:disable=broad-except + _LOGGER.warning( + "Message 'on_send_complete' callback failed: %r", + e, + extra=self._network_trace_params + ) + self.settled = True + + +class SenderLink(Link): + def __init__(self, session, handle, target_address, **kwargs): + name = kwargs.pop("name", None) or str(uuid.uuid4()) + role = Role.Sender + if "source_address" not in kwargs: + kwargs["source_address"] = "sender-link-{}".format(name) + super(SenderLink, self).__init__(session, handle, name, role, target_address=target_address, **kwargs) + self._pending_deliveries = [] + + @classmethod + def from_incoming_frame(cls, session, handle, frame): + # TODO: Assuming we 
establish all links for now... + # check link_create_from_endpoint in C lib + raise NotImplementedError("Pending") + + # In theory we should not need to purge pending deliveries on attach/detach - as a link should + # be resume-able, however this is not yet supported. + def _incoming_attach(self, frame): + try: + super(SenderLink, self)._incoming_attach(frame) + except AMQPLinkError: + self._remove_pending_deliveries() + raise + self.current_link_credit = self.link_credit + self._outgoing_flow() + self.update_pending_deliveries() + + def _incoming_detach(self, frame): + super(SenderLink, self)._incoming_detach(frame) + self._remove_pending_deliveries() + + def _incoming_flow(self, frame): + rcv_link_credit = frame[6] # link_credit + rcv_delivery_count = frame[5] # delivery_count + if frame[4] is not None: # handle + if rcv_link_credit is None or rcv_delivery_count is None: + _LOGGER.info( + "Unable to get link-credit or delivery-count from incoming FLOW. Detaching link.", + extra=self.network_trace_params + ) + self._remove_pending_deliveries() + self._set_state(LinkState.DETACHED) # TODO: Send detach now?
+ else: + self.current_link_credit = rcv_delivery_count + rcv_link_credit - self.delivery_count + self.update_pending_deliveries() + + def _outgoing_transfer(self, delivery): + output = bytearray() + encode_payload(output, delivery.message) + delivery_count = self.delivery_count + 1 + delivery.frame = { + "handle": self.handle, + "delivery_tag": struct.pack(">I", abs(delivery_count)), + "message_format": delivery.message._code, # pylint:disable=protected-access + "settled": delivery.settled, + "more": False, + "rcv_settle_mode": None, + "state": None, + "resume": None, + "aborted": None, + "batchable": None, + "payload": output, + } + self._session._outgoing_transfer( # pylint:disable=protected-access + delivery, + self.network_trace_params if self.network_trace else None + ) + sent_and_settled = False + if delivery.transfer_state == SessionTransferState.OKAY: + self.delivery_count = delivery_count + self.current_link_credit -= 1 + delivery.sent = True + if delivery.settled: + delivery.on_settled(LinkDeliverySettleReason.SETTLED, None) + sent_and_settled = True + # elif delivery.transfer_state == SessionTransferState.ERROR: + # TODO: Session wasn't mapped yet - re-adding to the outgoing delivery queue? 
+ return sent_and_settled + + def _incoming_disposition(self, frame): + if not frame[3]: # settled + return + range_end = (frame[2] or frame[1]) + 1 # first or last + settled_ids = list(range(frame[1], range_end)) + unsettled = [] + for delivery in self._pending_deliveries: + if delivery.sent and delivery.frame["delivery_id"] in settled_ids: + delivery.on_settled(LinkDeliverySettleReason.DISPOSITION_RECEIVED, frame[4]) # state + continue + unsettled.append(delivery) + self._pending_deliveries = unsettled + + def _remove_pending_deliveries(self): + for delivery in self._pending_deliveries: + delivery.on_settled(LinkDeliverySettleReason.NOT_DELIVERED, None) + self._pending_deliveries = [] + + def _on_session_state_change(self): + if self._session.state == SessionState.DISCARDING: + self._remove_pending_deliveries() + super()._on_session_state_change() + + def update_pending_deliveries(self): + if self.current_link_credit <= 0: + self.current_link_credit = self.link_credit + self._outgoing_flow() + now = time.time() + pending = [] + for delivery in self._pending_deliveries: + if delivery.timeout and (now - delivery.start) >= delivery.timeout: + delivery.on_settled(LinkDeliverySettleReason.TIMEOUT, None) + continue + if not delivery.sent: + sent_and_settled = self._outgoing_transfer(delivery) + if sent_and_settled: + continue + pending.append(delivery) + self._pending_deliveries = pending + + def send_transfer(self, message, *, send_async=False, **kwargs): + self._check_if_closed() + if self.state != LinkState.ATTACHED: + raise AMQPLinkError( + condition=ErrorCondition.ClientError, + description="Link is not attached." 
+ ) + settled = self.send_settle_mode == SenderSettleMode.Settled + if self.send_settle_mode == SenderSettleMode.Mixed: + settled = kwargs.pop("settled", True) + delivery = PendingDelivery( + on_delivery_settled=kwargs.get("on_send_complete"), + timeout=kwargs.get("timeout"), + message=message, + settled=settled, + network_trace_params = self.network_trace_params + ) + if self.current_link_credit == 0 or send_async: + self._pending_deliveries.append(delivery) + else: + sent_and_settled = self._outgoing_transfer(delivery) + if not sent_and_settled: + self._pending_deliveries.append(delivery) + return delivery + + def cancel_transfer(self, delivery): + try: + index = self._pending_deliveries.index(delivery) + except ValueError: + raise ValueError("Found no matching pending transfer.") + delivery = self._pending_deliveries[index] + if delivery.sent: + raise MessageException( + ErrorCondition.ClientError, + message="Transfer cannot be cancelled. Message has already been sent and awaiting disposition.", + ) + delivery.on_settled(LinkDeliverySettleReason.CANCELLED, None) + self._pending_deliveries.pop(index) diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/session.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/session.py new file mode 100644 index 0000000000000..bc5941f414fc7 --- /dev/null +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/session.py @@ -0,0 +1,507 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. 
+# -------------------------------------------------------------------------- + +from __future__ import annotations +import uuid +import logging +import time +from typing import Union, Optional + +from .constants import ConnectionState, SessionState, SessionTransferState, Role +from .sender import SenderLink +from .receiver import ReceiverLink +from .management_link import ManagementLink +from .performatives import ( + BeginFrame, + EndFrame, + FlowFrame, + TransferFrame, + DispositionFrame, +) +from .error import AMQPError, ErrorCondition +from ._encode import encode_frame + +_LOGGER = logging.getLogger(__name__) + + +class Session(object): # pylint: disable=too-many-instance-attributes + """ + :param int remote_channel: The remote channel for this Session. + :param int next_outgoing_id: The transfer-id of the first transfer id the sender will send. + :param int incoming_window: The initial incoming-window of the sender. + :param int outgoing_window: The initial outgoing-window of the sender. + :param int handle_max: The maximum handle value that may be used on the Session. + :param list(str) offered_capabilities: The extension capabilities the sender supports. + :param list(str) desired_capabilities: The extension capabilities the sender may use if the receiver supports them. + :param dict properties: Session properties.
+ """ + + def __init__(self, connection, channel, **kwargs): + self.name = kwargs.pop("name", None) or str(uuid.uuid4()) + self.state = SessionState.UNMAPPED + self.handle_max = kwargs.get("handle_max", 4294967295) + self.properties = kwargs.pop("properties", None) + self.remote_properties = None + self.channel = channel + self.remote_channel = None + self.next_outgoing_id = kwargs.pop("next_outgoing_id", 0) + self.next_incoming_id = None + self.incoming_window = kwargs.pop("incoming_window", 1) + self.outgoing_window = kwargs.pop("outgoing_window", 1) + self.target_incoming_window = self.incoming_window + self.remote_incoming_window = 0 + self.remote_outgoing_window = 0 + self.offered_capabilities = None + self.desired_capabilities = kwargs.pop("desired_capabilities", None) + + self.allow_pipelined_open = kwargs.pop("allow_pipelined_open", True) + self.idle_wait_time = kwargs.get("idle_wait_time", 0.1) + self.network_trace = kwargs["network_trace"] + self.network_trace_params = kwargs["network_trace_params"] + self.network_trace_params["amqpSession"] = self.name + + self.links = {} + self._connection = connection + self._output_handles = {} + self._input_handles = {} + + def __enter__(self): + self.begin() + return self + + def __exit__(self, *args): + self.end() + + @classmethod + def from_incoming_frame(cls, connection, channel): + # TODO: check session_create_from_endpoint in C lib + new_session = cls(connection, channel) + return new_session + + def _set_state(self, new_state): + # type: (SessionState) -> None + """Update the session state.""" + if new_state is None: + return + previous_state = self.state + self.state = new_state + _LOGGER.info( + "Session state changed: %r -> %r", + previous_state, + new_state, + extra=self.network_trace_params, + ) + for link in self.links.values(): + link._on_session_state_change() # pylint: disable=protected-access + + def _on_connection_state_change(self): + if self._connection.state in [ConnectionState.CLOSE_RCVD, 
ConnectionState.END]: + if self.state not in [SessionState.DISCARDING, SessionState.UNMAPPED]: + self._set_state(SessionState.DISCARDING) + + def _get_next_output_handle(self): + # type: () -> int + """Get the next available outgoing handle number within the max handle limit. + + :raises ValueError: If maximum handle has been reached. + :returns: The next available outgoing handle number. + :rtype: int + """ + if len(self._output_handles) >= self.handle_max: + raise ValueError( + "Maximum number of handles ({}) has been reached.".format( + self.handle_max + ) + ) + next_handle = next( + i for i in range(1, self.handle_max) if i not in self._output_handles + ) + return next_handle + + def _outgoing_begin(self): + begin_frame = BeginFrame( + remote_channel=self.remote_channel + if self.state == SessionState.BEGIN_RCVD + else None, + next_outgoing_id=self.next_outgoing_id, + outgoing_window=self.outgoing_window, + incoming_window=self.incoming_window, + handle_max=self.handle_max, + offered_capabilities=self.offered_capabilities + if self.state == SessionState.BEGIN_RCVD + else None, + desired_capabilities=self.desired_capabilities + if self.state == SessionState.UNMAPPED + else None, + properties=self.properties, + ) + if self.network_trace: + _LOGGER.debug("-> %r", begin_frame, extra=self.network_trace_params) + self._connection._process_outgoing_frame( # pylint: disable=protected-access + self.channel, begin_frame + ) + + def _incoming_begin(self, frame): + if self.network_trace: + _LOGGER.debug("<- %r", BeginFrame(*frame), extra=self.network_trace_params) + self.handle_max = frame[4] # handle_max + self.next_incoming_id = frame[1] # next_outgoing_id + self.remote_incoming_window = frame[2] # incoming_window + self.remote_outgoing_window = frame[3] # outgoing_window + self.remote_properties = frame[7] + if self.state == SessionState.BEGIN_SENT: + self.remote_channel = frame[0] # remote_channel + self._set_state(SessionState.MAPPED) + elif self.state == 
SessionState.UNMAPPED: + self._set_state(SessionState.BEGIN_RCVD) + self._outgoing_begin() + self._set_state(SessionState.MAPPED) + + def _outgoing_end(self, error=None): + end_frame = EndFrame(error=error) + if self.network_trace: + _LOGGER.debug("-> %r", end_frame, extra=self.network_trace_params) + self._connection._process_outgoing_frame( # pylint: disable=protected-access + self.channel, end_frame + ) + + def _incoming_end(self, frame): + if self.network_trace: + _LOGGER.debug("<- %r", EndFrame(*frame), extra=self.network_trace_params) + if self.state not in [ + SessionState.END_RCVD, + SessionState.END_SENT, + SessionState.DISCARDING, + ]: + self._set_state(SessionState.END_RCVD) + for _, link in self.links.items(): + link.detach() + # TODO: handling error + self._outgoing_end() + self._set_state(SessionState.UNMAPPED) + + def _outgoing_attach(self, frame): + self._connection._process_outgoing_frame( # pylint: disable=protected-access + self.channel, frame + ) + + def _incoming_attach(self, frame): + try: + self._input_handles[frame[1]] = self.links[ + frame[0].decode("utf-8") + ] # name and handle + self._input_handles[frame[1]]._incoming_attach( # pylint: disable=protected-access + frame + ) + except KeyError: + try: + outgoing_handle = self._get_next_output_handle() + except ValueError: + _LOGGER.error( + "Unable to attach new link - cannot allocate more handles.", + extra=self.network_trace_params + ) + # detach the link that would have been set. + self.links[frame[0].decode("utf-8")].detach( + error=AMQPError( + condition=ErrorCondition.LinkDetachForced, + description="""Cannot allocate more handles, """ + """the max number of handles is {}. 
Detaching link""".format( + self.handle_max + ), + info=None, + ) + ) + return + if frame[2] == Role.Sender: # role + new_link = ReceiverLink.from_incoming_frame( + self, outgoing_handle, frame + ) + else: + new_link = SenderLink.from_incoming_frame(self, outgoing_handle, frame) + new_link._incoming_attach(frame) # pylint: disable=protected-access + self.links[frame[0]] = new_link + self._output_handles[outgoing_handle] = new_link + self._input_handles[frame[1]] = new_link + except ValueError as e: + # Reject Link + _LOGGER.error( + "Unable to attach new link: %r", + e, + extra=self.network_trace_params + ) + self._input_handles[frame[1]].detach() + + def _outgoing_flow(self, frame=None): + link_flow = frame or {} + link_flow.update( + { + "next_incoming_id": self.next_incoming_id, + "incoming_window": self.incoming_window, + "next_outgoing_id": self.next_outgoing_id, + "outgoing_window": self.outgoing_window, + } + ) + flow_frame = FlowFrame(**link_flow) + if self.network_trace: + _LOGGER.debug("-> %r", flow_frame, extra=self.network_trace_params) + self._connection._process_outgoing_frame( # pylint: disable=protected-access + self.channel, flow_frame + ) + + def _incoming_flow(self, frame): + if self.network_trace: + _LOGGER.debug("<- %r", FlowFrame(*frame), extra=self.network_trace_params) + self.next_incoming_id = frame[2] # next_outgoing_id + remote_incoming_id = ( + frame[0] or self.next_outgoing_id + ) # next_incoming_id TODO "initial-outgoing-id" + self.remote_incoming_window = ( + remote_incoming_id + frame[1] - self.next_outgoing_id + ) # incoming_window + self.remote_outgoing_window = frame[3] # outgoing_window + if frame[4] is not None: # handle + self._input_handles[frame[4]]._incoming_flow( # pylint: disable=protected-access + frame + ) + else: + for link in self._output_handles.values(): + if ( + self.remote_incoming_window > 0 and not link._is_closed # pylint: disable=protected-access + ): + link._incoming_flow(frame) # pylint: 
disable=protected-access + + def _outgoing_transfer(self, delivery, network_trace_params): + if self.state != SessionState.MAPPED: + delivery.transfer_state = SessionTransferState.ERROR + if self.remote_incoming_window <= 0: + delivery.transfer_state = SessionTransferState.BUSY + else: + payload = delivery.frame["payload"] + payload_size = len(payload) + + delivery.frame["delivery_id"] = self.next_outgoing_id + # calculate the transfer frame encoding size excluding the payload + delivery.frame["payload"] = b"" + # TODO: encoding a frame would be expensive, we might want to improve depending on the perf test results + encoded_frame = encode_frame(TransferFrame(**delivery.frame))[1] + transfer_overhead_size = len(encoded_frame) + + # available size for payload per frame is calculated as following: + # remote max frame size - transfer overhead (calculated) - header (8 bytes) + available_frame_size = ( + self._connection._remote_max_frame_size - transfer_overhead_size - 8 # pylint: disable=protected-access + ) + + start_idx = 0 + remaining_payload_cnt = payload_size + # encode n-1 frames if payload_size > available_frame_size + while remaining_payload_cnt > available_frame_size: + tmp_delivery_frame = { + "handle": delivery.frame["handle"], + "delivery_tag": delivery.frame["delivery_tag"], + "message_format": delivery.frame["message_format"], + "settled": delivery.frame["settled"], + "more": True, + "rcv_settle_mode": delivery.frame["rcv_settle_mode"], + "state": delivery.frame["state"], + "resume": delivery.frame["resume"], + "aborted": delivery.frame["aborted"], + "batchable": delivery.frame["batchable"], + "delivery_id": self.next_outgoing_id, + } + if network_trace_params: + # We determine the logging for the outgoing Transfer frames based on the source + # Link configuration rather than the Session, because it's only at the Session + # level that we can determine how many outgoing frames are needed and their + # delivery IDs. 
+ # TODO: Obscuring the payload for now to investigate the potential for leaks. + _LOGGER.debug( + "-> %r", TransferFrame(payload=b"***", **tmp_delivery_frame), + extra=network_trace_params + ) + self._connection._process_outgoing_frame( # pylint: disable=protected-access + self.channel, + TransferFrame( + payload=payload[start_idx : start_idx + available_frame_size], + **tmp_delivery_frame + ) + ) + start_idx += available_frame_size + remaining_payload_cnt -= available_frame_size + + # encode the last frame + tmp_delivery_frame = { + "handle": delivery.frame["handle"], + "delivery_tag": delivery.frame["delivery_tag"], + "message_format": delivery.frame["message_format"], + "settled": delivery.frame["settled"], + "more": False, + "rcv_settle_mode": delivery.frame["rcv_settle_mode"], + "state": delivery.frame["state"], + "resume": delivery.frame["resume"], + "aborted": delivery.frame["aborted"], + "batchable": delivery.frame["batchable"], + "delivery_id": self.next_outgoing_id, + } + if network_trace_params: + # We determine the logging for the outgoing Transfer frames based on the source + # Link configuration rather than the Session, because it's only at the Session + # level that we can determine how many outgoing frames are needed and their + # delivery IDs. + # TODO: Obscuring the payload for now to investigate the potential for leaks. 
+ _LOGGER.debug( + "-> %r", TransferFrame(payload=b"***", **tmp_delivery_frame), + extra=network_trace_params + ) + self._connection._process_outgoing_frame( # pylint: disable=protected-access + self.channel, + TransferFrame(payload=payload[start_idx:], **tmp_delivery_frame) + ) + self.next_outgoing_id += 1 + self.remote_incoming_window -= 1 + self.outgoing_window -= 1 + # TODO: We should probably handle an error at the connection and update state accordingly + delivery.transfer_state = SessionTransferState.OKAY + + def _incoming_transfer(self, frame): + self.next_incoming_id += 1 + self.remote_outgoing_window -= 1 + self.incoming_window -= 1 + try: + self._input_handles[frame[0]]._incoming_transfer( # pylint: disable=protected-access + frame + ) + except KeyError: + _LOGGER.error( + "Received Transfer frame on unattached link. Ending session.", + extra=self.network_trace_params + ) + self._set_state(SessionState.DISCARDING) + self.end( + error=AMQPError( + condition=ErrorCondition.SessionUnattachedHandle, + description="""Invalid handle reference in received frame: """ + """Handle is not currently associated with an attached link""", + ) + ) + return + if self.incoming_window == 0: + self.incoming_window = self.target_incoming_window + self._outgoing_flow() + + def _outgoing_disposition(self, frame): + self._connection._process_outgoing_frame( # pylint: disable=protected-access + self.channel, frame + ) + + def _incoming_disposition(self, frame): + if self.network_trace: + _LOGGER.debug( + "<- %r", DispositionFrame(*frame), extra=self.network_trace_params + ) + for link in self._input_handles.values(): + link._incoming_disposition(frame) # pylint: disable=protected-access + + def _outgoing_detach(self, frame): + self._connection._process_outgoing_frame( # pylint: disable=protected-access + self.channel, frame + ) + + def _incoming_detach(self, frame): + try: + link = self._input_handles[frame[0]] # handle + link._incoming_detach(frame) # pylint: 
disable=protected-access + # if link._is_closed: TODO + # self.links.pop(link.name, None) + # self._input_handles.pop(link.remote_handle, None) + # self._output_handles.pop(link.handle, None) + except KeyError: + self._set_state(SessionState.DISCARDING) + self._connection.close( + error=AMQPError( + condition=ErrorCondition.SessionUnattachedHandle, + description="""Invalid handle reference in received frame: """ + """Handle is not currently associated with an attached link""", + ) + ) + + def _wait_for_response(self, wait, end_state): + # type: (Union[bool, float], SessionState) -> None + if wait is True: + self._connection.listen(wait=False) + while self.state != end_state: + time.sleep(self.idle_wait_time) + self._connection.listen(wait=False) + elif wait: + self._connection.listen(wait=False) + timeout = time.time() + wait + while self.state != end_state: + if time.time() >= timeout: + break + time.sleep(self.idle_wait_time) + self._connection.listen(wait=False) + + def begin(self, wait=False): + self._outgoing_begin() + self._set_state(SessionState.BEGIN_SENT) + if wait: + self._wait_for_response(wait, SessionState.MAPPED) + elif not self.allow_pipelined_open: + raise ValueError( + "Connection has been configured to not allow pipelined-open. Please set 'wait' parameter."
+ ) + + def end(self, error=None, wait=False): + # type: (Optional[AMQPError], bool) -> None + try: + if self.state not in [SessionState.UNMAPPED, SessionState.DISCARDING]: + self._outgoing_end(error=error) + for _, link in self.links.items(): + link.detach() + new_state = SessionState.DISCARDING if error else SessionState.END_SENT + self._set_state(new_state) + self._wait_for_response(wait, SessionState.UNMAPPED) + except Exception as exc: # pylint: disable=broad-except + _LOGGER.info("An error occurred when ending the session: %r", exc, extra=self.network_trace_params) + self._set_state(SessionState.UNMAPPED) + + def create_receiver_link(self, source_address, **kwargs): + assigned_handle = self._get_next_output_handle() + link = ReceiverLink( + self, + handle=assigned_handle, + source_address=source_address, + network_trace=kwargs.pop("network_trace", self.network_trace), + network_trace_params=dict(self.network_trace_params), + **kwargs, + ) + self.links[link.name] = link + self._output_handles[assigned_handle] = link + return link + + def create_sender_link(self, target_address, **kwargs): + assigned_handle = self._get_next_output_handle() + link = SenderLink( + self, + handle=assigned_handle, + target_address=target_address, + network_trace=kwargs.pop("network_trace", self.network_trace), + network_trace_params=dict(self.network_trace_params), + **kwargs, + ) + self._output_handles[assigned_handle] = link + self.links[link.name] = link + return link + + def create_request_response_link_pair(self, endpoint, **kwargs): + return ManagementLink( + self, + endpoint, + network_trace=kwargs.pop("network_trace", self.network_trace), + network_trace_params=dict(self.network_trace_params), + **kwargs, + ) diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/types.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/types.py new file mode 100644 index 0000000000000..db478af591c8c --- /dev/null +++ 
b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/types.py @@ -0,0 +1,90 @@ +#------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +#-------------------------------------------------------------------------- + +from enum import Enum + + +TYPE = 'TYPE' +VALUE = 'VALUE' + + +class AMQPTypes(object): # pylint: disable=no-init + null = 'NULL' + boolean = 'BOOL' + ubyte = 'UBYTE' + byte = 'BYTE' + ushort = 'USHORT' + short = 'SHORT' + uint = 'UINT' + int = 'INT' + ulong = 'ULONG' + long = 'LONG' + float = 'FLOAT' + double = 'DOUBLE' + timestamp = 'TIMESTAMP' + uuid = 'UUID' + binary = 'BINARY' + string = 'STRING' + symbol = 'SYMBOL' + list = 'LIST' + map = 'MAP' + array = 'ARRAY' + described = 'DESCRIBED' + + +class FieldDefinition(Enum): + fields = "fields" + annotations = "annotations" + message_id = "message-id" + app_properties = "application-properties" + node_properties = "node-properties" + filter_set = "filter-set" + + +class ObjDefinition(Enum): + source = "source" + target = "target" + delivery_state = "delivery-state" + error = "error" + + +class ConstructorBytes(object): # pylint: disable=no-init + null = b'\x40' + bool = b'\x56' + bool_true = b'\x41' + bool_false = b'\x42' + ubyte = b'\x50' + byte = b'\x51' + ushort = b'\x60' + short = b'\x61' + uint_0 = b'\x43' + uint_small = b'\x52' + int_small = b'\x54' + uint_large = b'\x70' + int_large = b'\x71' + ulong_0 = b'\x44' + ulong_small = b'\x53' + long_small = b'\x55' + ulong_large = b'\x80' + long_large = b'\x81' + float = b'\x72' + double = b'\x82' + timestamp = b'\x83' + uuid = b'\x98' + binary_small = b'\xA0' + binary_large = b'\xB0' + string_small = b'\xA1' + string_large = b'\xB1' + symbol_small = b'\xA3' + symbol_large = b'\xB3' + list_0 = b'\x45' + list_small = b'\xC0' + list_large = b'\xD0' + map_small = b'\xC1' + 
map_large = b'\xD1' + array_small = b'\xE0' + array_large = b'\xF0' + descriptor = b'\x00' diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/utils.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/utils.py new file mode 100644 index 0000000000000..d8bdfdc3dc5bb --- /dev/null +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/utils.py @@ -0,0 +1,138 @@ +#------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +#-------------------------------------------------------------------------- +import datetime +from base64 import b64encode +from hashlib import sha256 +from hmac import HMAC +from urllib.parse import urlencode, quote_plus +import time + +from .types import TYPE, VALUE, AMQPTypes +from ._encode import encode_payload + + +class UTC(datetime.tzinfo): + """Time Zone info for handling UTC""" + + def utcoffset(self, dt): + """UTC offset for UTC is 0.""" + return datetime.timedelta(0) + + def tzname(self, dt): + """Timestamp representation.""" + return "Z" + + def dst(self, dt): + """No daylight saving for UTC.""" + return datetime.timedelta(0) + + +try: + from datetime import timezone # pylint: disable=ungrouped-imports + + TZ_UTC = timezone.utc # type: ignore +except ImportError: + TZ_UTC = UTC() # type: ignore + + +def utc_from_timestamp(timestamp): + return datetime.datetime.fromtimestamp(timestamp, tz=TZ_UTC) + + +def utc_now(): + return datetime.datetime.now(tz=TZ_UTC) + + +def encode(value, encoding='UTF-8'): + return value.encode(encoding) if isinstance(value, str) else value + + +def generate_sas_token(audience, policy, key, expiry=None): + """ + Generate a SAS token according to the given audience, policy, key and expiry. + + :param str audience: + :param str policy: + :param str key: + :param int expiry: abs expiry time + :rtype: str
+ """ + if not expiry: + expiry = int(time.time()) + 3600 # Default to 1 hour. + + encoded_uri = quote_plus(audience) + encoded_policy = quote_plus(policy).encode("utf-8") + encoded_key = key.encode("utf-8") + + ttl = int(expiry) + sign_key = '%s\n%d' % (encoded_uri, ttl) + signature = b64encode(HMAC(encoded_key, sign_key.encode('utf-8'), sha256).digest()) + result = { + 'sr': audience, + 'sig': signature, + 'se': str(ttl) + } + if policy: + result['skn'] = encoded_policy + return 'SharedAccessSignature ' + urlencode(result) + + +def add_batch(batch, message): + # Add a message to a batch + output = bytearray() + encode_payload(output, message) + batch[5].append(output) + + +def encode_str(data, encoding='utf-8'): + try: + return data.encode(encoding) + except AttributeError: + return data + + +def normalized_data_body(data, **kwargs): + # A helper method to normalize input into AMQP Data Body format + encoding = kwargs.get("encoding", "utf-8") + if isinstance(data, list): + return [encode_str(item, encoding) for item in data] + return [encode_str(data, encoding)] + + +def normalized_sequence_body(sequence): + # A helper method to normalize input into AMQP Sequence Body format + if isinstance(sequence, list) and all([isinstance(b, list) for b in sequence]): + return sequence + if isinstance(sequence, list): + return [sequence] + + +def get_message_encoded_size(message): + output = bytearray() + encode_payload(output, message) + return len(output) + + +def amqp_long_value(value): + # A helper method to wrap a Python int as AMQP long + # TODO: wrapping one line in a function is expensive, find if there's a better way to do it + return {TYPE: AMQPTypes.long, VALUE: value} + + +def amqp_uint_value(value): + # A helper method to wrap a Python int as AMQP uint + return {TYPE: AMQPTypes.uint, VALUE: value} + + +def amqp_string_value(value): + return {TYPE: AMQPTypes.string, VALUE: value} + + +def amqp_symbol_value(value): + return {TYPE: AMQPTypes.symbol, VALUE: value} + 
+def amqp_array_value(value): + return {TYPE: AMQPTypes.array, VALUE: value} diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_servicebus_client.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_servicebus_client.py index eff61d6c79bff..dc7f4f25f64e9 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/_servicebus_client.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_servicebus_client.py @@ -6,8 +6,7 @@ import logging from weakref import WeakSet from typing_extensions import Literal - -import uamqp +import certifi from ._base_handler import ( _parse_conn_str, @@ -29,6 +28,8 @@ ServiceBusSessionFilter, ) +from ._transport._pyamqp_transport import PyamqpTransport + if TYPE_CHECKING: from azure.core.credentials import ( TokenCredential, @@ -42,7 +43,7 @@ _LOGGER = logging.getLogger(__name__) -class ServiceBusClient(object): # pylint: disable=client-accepts-api-version-keyword +class ServiceBusClient(object): # pylint: disable=client-accepts-api-version-keyword,too-many-instance-attributes """The ServiceBusClient class defines a high level interface for getting ServiceBusSender and ServiceBusReceiver. @@ -84,6 +85,9 @@ class ServiceBusClient(object): # pylint: disable=client-accepts-api-version-key :keyword str connection_verify: Path to the custom CA_BUNDLE file of the SSL certificate which is used to authenticate the identity of the connection endpoint. Default is None in which case `certifi.where()` will be used. + :keyword uamqp_transport: Whether to use the `uamqp` library as the underlying transport. The default value is + False and the Pure Python AMQP library will be used as the underlying transport. + :paramtype uamqp_transport: bool .. 
admonition:: Example: @@ -109,50 +113,63 @@ def __init__( retry_mode: str = "exponential", **kwargs: Any ) -> None: + uamqp_transport = kwargs.pop("uamqp_transport", False) + if uamqp_transport: + try: + from ._transport._uamqp_transport import UamqpTransport + except ImportError: + raise ValueError("To use the uAMQP transport, please install `uamqp>=1.6.3,<2.0.0`.") + self._amqp_transport = UamqpTransport if uamqp_transport else PyamqpTransport + # If the user provided http:// or sb://, let's be polite and strip that. self.fully_qualified_namespace = strip_protocol_from_uri( fully_qualified_namespace.strip() ) self._credential = credential + # TODO: can we remove this here? it's recreated in Sender/Receiver self._config = Configuration( retry_total=retry_total, retry_backoff_factor=retry_backoff_factor, retry_backoff_max=retry_backoff_max, retry_mode=retry_mode, + hostname=self.fully_qualified_namespace, + amqp_transport=self._amqp_transport, **kwargs ) self._connection = None # Optional entity name, can be the name of Queue or Topic. Intentionally not advertised, typically not needed.
self._entity_name = kwargs.get("entity_name") - self._auth_uri = "sb://{}".format(self.fully_qualified_namespace) + self._auth_uri = f"sb://{self.fully_qualified_namespace}" if self._entity_name: - self._auth_uri = "{}/{}".format(self._auth_uri, self._entity_name) + self._auth_uri = f"{self._auth_uri}/{self._entity_name}" # Internal flag for switching whether to apply connection sharing, pending fix in uamqp library self._connection_sharing = False - self._handlers = WeakSet() # type: WeakSet - + self._handlers: WeakSet = WeakSet() self._custom_endpoint_address = kwargs.get('custom_endpoint_address') self._connection_verify = kwargs.get("connection_verify") def __enter__(self): if self._connection_sharing: - self._create_uamqp_connection() + self._create_connection() return self def __exit__(self, *args): self.close() - def _create_uamqp_connection(self): + def _create_connection(self): auth = create_authentication(self) - self._connection = uamqp.Connection( - hostname=self.fully_qualified_namespace, - sasl=auth, - debug=self._config.logging_enable, + self._connection = self._amqp_transport.create_connection( + host=self.fully_qualified_namespace, + auth=auth.sasl, + network_trace=self._config.logging_enable, + custom_endpoint_address=self._custom_endpoint_address, + ssl_opts={'ca_certs': self._connection_verify or certifi.where()}, + transport_type=self._config.transport_type, + http_proxy=self._config.http_proxy, ) - def close(self): - # type: () -> None + def close(self) -> None: """ Close down the ServiceBus client. All spawned senders, receivers and underlying connection will be shutdown. 
@@ -172,7 +189,7 @@ def close(self): self._handlers.clear() if self._connection_sharing and self._connection: - self._connection.destroy() + self._connection.close() @classmethod def from_connection_string( @@ -215,6 +232,9 @@ def from_connection_string( :keyword str connection_verify: Path to the custom CA_BUNDLE file of the SSL certificate which is used to authenticate the identity of the connection endpoint. Default is None in which case `certifi.where()` will be used. + :keyword uamqp_transport: Whether to use the `uamqp` library as the underlying transport. The default value is + False and the Pure Python AMQP library will be used as the underlying transport. + :paramtype uamqp_transport: bool :rtype: ~azure.servicebus.ServiceBusClient .. admonition:: Example: @@ -245,8 +265,11 @@ def from_connection_string( **kwargs ) - def get_queue_sender(self, queue_name, **kwargs): - # type: (str, Any) -> ServiceBusSender + def get_queue_sender( + self, + queue_name: str, + **kwargs: Any + ) -> ServiceBusSender: """Get ServiceBusSender for the specific queue. :param str queue_name: The path of specific Service Bus Queue the client connects to. 
@@ -288,6 +311,7 @@ def get_queue_sender(self, queue_name, **kwargs): retry_backoff_max=self._config.retry_backoff_max, custom_endpoint_address=self._custom_endpoint_address, connection_verify=self._connection_verify, + amqp_transport=self._amqp_transport, **kwargs ) self._handlers.add(handler) @@ -402,6 +426,7 @@ def get_queue_receiver( prefetch_count=prefetch_count, custom_endpoint_address=self._custom_endpoint_address, connection_verify=self._connection_verify, + amqp_transport=self._amqp_transport, **kwargs ) self._handlers.add(handler) @@ -449,6 +474,7 @@ def get_topic_sender(self, topic_name, **kwargs): retry_backoff_max=self._config.retry_backoff_max, custom_endpoint_address=self._custom_endpoint_address, connection_verify=self._connection_verify, + amqp_transport=self._amqp_transport, **kwargs ) self._handlers.add(handler) @@ -562,6 +588,7 @@ def get_subscription_receiver( prefetch_count=prefetch_count, custom_endpoint_address=self._custom_endpoint_address, connection_verify=self._connection_verify, + amqp_transport=self._amqp_transport, **kwargs ) except ValueError: @@ -591,6 +618,7 @@ def get_subscription_receiver( prefetch_count=prefetch_count, custom_endpoint_address=self._custom_endpoint_address, connection_verify=self._connection_verify, + amqp_transport=self._amqp_transport, **kwargs ) self._handlers.add(handler) diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_servicebus_receiver.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_servicebus_receiver.py index d34a6a688eb6a..0fe7bb974b60b 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/_servicebus_receiver.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_servicebus_receiver.py @@ -10,19 +10,20 @@ import uuid import datetime import warnings +from enum import Enum from typing import Any, List, Optional, Dict, Iterator, Union, TYPE_CHECKING, cast -from uamqp import ReceiveClient, types, Message -from uamqp.constants import SenderSettleMode -from 
uamqp.authentication.common import AMQPAuth - from .exceptions import ServiceBusError from ._base_handler import BaseHandler from ._common.message import ServiceBusReceivedMessage -from ._common.utils import ( - create_authentication, +from ._common.utils import create_authentication +from ._common.tracing import ( get_receive_links, receive_trace_context_manager, + settle_trace_context_manager, + get_span_link_from_message, + SPAN_NAME_RECEIVE_DEFERRED, + SPAN_NAME_PEEK, ) from ._common.constants import ( CONSUMER_IDENTIFIER, @@ -37,8 +38,6 @@ MGMT_REQUEST_RECEIVER_SETTLE_MODE, MGMT_REQUEST_FROM_SEQUENCE_NUMBER, MGMT_REQUEST_MAX_MESSAGE_COUNT, - SPAN_NAME_RECEIVE_DEFERRED, - SPAN_NAME_PEEK, MESSAGE_COMPLETE, MESSAGE_ABANDON, MESSAGE_DEFER, @@ -48,7 +47,6 @@ MGMT_REQUEST_DEAD_LETTER_REASON, MGMT_REQUEST_DEAD_LETTER_ERROR_DESCRIPTION, MGMT_RESPONSE_MESSAGE_EXPIRATION, - ServiceBusToAMQPReceiveModeMap, ) from ._common import mgmt_handlers from ._common.receiver_mixins import ReceiverMixin @@ -56,6 +54,16 @@ from ._servicebus_session import ServiceBusSession if TYPE_CHECKING: + try: + # pylint:disable=unused-import + from uamqp import ReceiveClient as uamqp_ReceiveClientSync, Message as uamqp_Message + from uamqp.authentication import JWTTokenAuth as uamqp_JWTTokenAuth + except ImportError: + pass + from ._transport._base import AmqpTransport + from ._pyamqp.client import ReceiveClient as pyamqp_ReceiveClientSync + from ._pyamqp.message import Message as pyamqp_Message + from ._pyamqp.authentication import JWTTokenAuth as pyamqp_JWTTokenAuth from ._common.auto_lock_renewer import AutoLockRenewer from azure.core.credentials import ( TokenCredential, @@ -147,7 +155,9 @@ def __init__( prefetch_count: int = 0, **kwargs: Any, ) -> None: - self._message_iter = None # type: Optional[Iterator[ServiceBusReceivedMessage]] + self._session_id = None + self._message_iter: Optional[Iterator[ServiceBusReceivedMessage]] = None + self._amqp_transport: "AmqpTransport" if 
kwargs.get("entity_name"): super(ServiceBusReceiver, self).__init__( fully_qualified_namespace=fully_qualified_namespace, @@ -203,44 +213,32 @@ def __init__( self._session = ( None if self._session_id is None - else ServiceBusSession(self._session_id, self) + else ServiceBusSession(cast(str, self._session_id), self) ) self._receive_context = threading.Event() + self._handler: Union["pyamqp_ReceiveClientSync", "uamqp_ReceiveClientSync"] + self._build_received_message = functools.partial( + self._amqp_transport.build_received_message, + self, + ServiceBusReceivedMessage + ) + self._iter_contextual_wrapper = functools.partial( + self._amqp_transport.iter_contextual_wrapper, self + ) + self._iter_next = functools.partial( + self._amqp_transport.iter_next, + self + ) def __iter__(self): return self._iter_contextual_wrapper() - def _iter_contextual_wrapper(self, max_wait_time=None): - """The purpose of this wrapper is to allow both state restoration (for multiple concurrent iteration) - and per-iter argument passing that requires the former.""" - # pylint: disable=protected-access - original_timeout = None - while True: - # This is not threadsafe, but gives us a way to handle if someone passes - # different max_wait_times to different iterators and uses them in concert. - if max_wait_time: - original_timeout = self._handler._timeout - self._handler._timeout = max_wait_time * 1000 - try: - message = self._inner_next() - links = get_receive_links(message) - with receive_trace_context_manager(self, links=links): - yield message - except StopIteration: - break - finally: - if original_timeout: - try: - self._handler._timeout = original_timeout - except AttributeError: # Handler may be disposed already. - pass - - def _inner_next(self): + def _inner_next(self, wait_time=None): # We do this weird wrapping such that an imperitive next() call, and a generator-based iter both trace sanely. 
self._check_live() while True: try: - return self._do_retryable_operation(self._iter_next) + return self._do_retryable_operation(self._iter_next, wait_time=wait_time) except StopIteration: self._message_iter = None raise @@ -258,27 +256,10 @@ def __next__(self): next = __next__ # for python2.7 - def _iter_next(self): - try: - self._receive_context.set() - self._open() - if not self._message_iter: - self._message_iter = self._handler.receive_messages_iter() - uamqp_message = next(self._message_iter) - message = self._build_message(uamqp_message) - if ( - self._auto_lock_renewer - and not self._session - and self._receive_mode != ServiceBusReceiveMode.RECEIVE_AND_DELETE - ): - self._auto_lock_renewer.register(self, message) - return message - finally: - self._receive_context.clear() - @classmethod - def _from_connection_string(cls, conn_str, **kwargs): - # type: (str, Any) -> ServiceBusReceiver + def _from_connection_string( + cls, conn_str: str, **kwargs: Any + ) -> "ServiceBusReceiver": """Create a ServiceBusReceiver from a connection string. :param conn_str: The connection string of a Service Bus. 
@@ -337,24 +318,21 @@ def _from_connection_string(cls, conn_str, **kwargs): ) return cls(**constructor_args) - def _create_handler(self, auth): - # type: (AMQPAuth) -> None - self._handler = ReceiveClient( - self._get_source(), + def _create_handler(self, auth: Union["pyamqp_JWTTokenAuth", "uamqp_JWTTokenAuth"]) -> None: + + self._handler = self._amqp_transport.create_receive_client( + receiver=self, + source=self._get_source(), auth=auth, - debug=self._config.logging_enable, + network_trace=self._config.logging_enable, properties=self._properties, - error_policy=self._error_policy, + retry_policy=self._error_policy, client_name=self._name, - on_attach=self._on_attach, - auto_complete=False, - encoding=self._config.encoding, - receive_settle_mode=ServiceBusToAMQPReceiveModeMap[self._receive_mode], - send_settle_mode=SenderSettleMode.Settled - if self._receive_mode == ServiceBusReceiveMode.RECEIVE_AND_DELETE - else None, - timeout=self._max_wait_time * 1000 if self._max_wait_time else 0, - prefetch=self._prefetch_count, + receive_mode=self._receive_mode, + timeout=self._max_wait_time * self._amqp_transport.TIMEOUT_FACTOR + if self._max_wait_time + else 0, + link_credit=self._prefetch_count, # If prefetch is 1, then keep_alive coroutine serves as keep receiving for releasing messages keep_alive_interval=self._config.keep_alive if self._prefetch_count != 1 @@ -363,7 +341,11 @@ def _create_handler(self, auth): link_properties={CONSUMER_IDENTIFIER: self._name}, ) if self._prefetch_count == 1: - self._handler._message_received = self._enhanced_message_received # pylint: disable=protected-access + # pylint: disable=protected-access + self._handler._message_received = functools.partial( + self._amqp_transport.enhanced_message_received, # type: ignore[attr-defined] + self + ) def _open(self): # pylint: disable=protected-access @@ -386,8 +368,9 @@ def _open(self): if self._auto_lock_renewer and self._session: self._auto_lock_renewer.register(self, self.session) - def 
_receive(self, max_message_count=None, timeout=None): - # type: (Optional[int], Optional[float]) -> List[ServiceBusReceivedMessage] + def _receive( + self, max_message_count: Optional[int] = None, timeout: Optional[float] = None + ) -> List[ServiceBusReceivedMessage]: # pylint: disable=protected-access try: self._receive_context.set() @@ -396,24 +379,24 @@ def _receive(self, max_message_count=None, timeout=None): amqp_receive_client = self._handler received_messages_queue = amqp_receive_client._received_messages max_message_count = max_message_count or self._prefetch_count - timeout_ms = ( - 1000 * (timeout or self._max_wait_time) + timeout_time = ( + self._amqp_transport.TIMEOUT_FACTOR * (timeout or self._max_wait_time) if (timeout or self._max_wait_time) else 0 ) - abs_timeout_ms = ( - amqp_receive_client._counter.get_current_ms() + timeout_ms - if timeout_ms + abs_timeout = ( + self._amqp_transport.get_current_time(amqp_receive_client) + timeout_time + if (timeout_time) else 0 ) - batch = [] # type: List[Message] + batch: Union[List["uamqp_Message"], List["pyamqp_Message"]] = [] while ( not received_messages_queue.empty() and len(batch) < max_message_count ): batch.append(received_messages_queue.get()) received_messages_queue.task_done() if len(batch) >= max_message_count: - return [self._build_message(message) for message in batch] + return [self._build_received_message(message) for message in batch] # Dynamically issue link credit if max_message_count > 1 when the prefetch_count is the default value 1 if ( @@ -422,18 +405,16 @@ def _receive(self, max_message_count=None, timeout=None): and max_message_count > 1 ): link_credit_needed = max_message_count - len(batch) - amqp_receive_client.message_handler.reset_link_credit( - link_credit_needed - ) + self._amqp_transport.reset_link_credit(amqp_receive_client, link_credit_needed) first_message_received = expired = False receiving = True while receiving and not expired and len(batch) < max_message_count: while 
receiving and received_messages_queue.qsize() < max_message_count: if ( - abs_timeout_ms - and amqp_receive_client._counter.get_current_ms() - > abs_timeout_ms + abs_timeout + and self._amqp_transport.get_current_time(amqp_receive_client) + > abs_timeout ): expired = True break @@ -447,9 +428,9 @@ def _receive(self, max_message_count=None, timeout=None): ): # first message(s) received, continue receiving for some time first_message_received = True - abs_timeout_ms = ( - amqp_receive_client._counter.get_current_ms() - + self._further_pull_receive_timeout_ms + abs_timeout = ( + self._amqp_transport.get_current_time(amqp_receive_client) + + self._further_pull_receive_timeout ) while ( not received_messages_queue.empty() @@ -458,7 +439,7 @@ def _receive(self, max_message_count=None, timeout=None): batch.append(received_messages_queue.get()) received_messages_queue.task_done() - return [self._build_message(message) for message in batch] + return [self._build_received_message(message) for message in batch] finally: self._receive_context.clear() @@ -489,36 +470,37 @@ def _settle_message_with_retry( message="The lock on the message lock has expired.", error=message.auto_renew_error, ) - - self._do_retryable_operation( - self._settle_message, - timeout=None, - message=message, - settle_operation=settle_operation, - dead_letter_reason=dead_letter_reason, - dead_letter_error_description=dead_letter_error_description, - ) - - message._settled = True + link = get_span_link_from_message(message) + trace_links = [link] if link else [] + with settle_trace_context_manager(self, settle_operation, links=trace_links): + self._do_retryable_operation( + self._settle_message, + timeout=None, + message=message, + settle_operation=settle_operation, + dead_letter_reason=dead_letter_reason, + dead_letter_error_description=dead_letter_error_description, + ) + message._settled = True def _settle_message( self, - message, - settle_operation, - dead_letter_reason=None, - 
dead_letter_error_description=None, - ): - # type: (ServiceBusReceivedMessage, str, Optional[str], Optional[str]) -> None + message: ServiceBusReceivedMessage, + settle_operation: str, + dead_letter_reason: Optional[str] = None, + dead_letter_error_description: Optional[str] = None, + ) -> None: # pylint: disable=protected-access try: if not message._is_deferred_message: try: - self._settle_message_via_receiver_link( + self._amqp_transport.settle_message_via_receiver_link( + self._handler, message, settle_operation, dead_letter_reason=dead_letter_reason, dead_letter_error_description=dead_letter_error_description, - )() + ) return except RuntimeError as exception: _LOGGER.info( @@ -550,12 +532,14 @@ def _settle_message( raise def _settle_message_via_mgmt_link( - self, settlement, lock_tokens, dead_letter_details=None - ): - # type: (str, List[Union[uuid.UUID, str]], Optional[Dict[str, Any]]) -> Any + self, + settlement: str, + lock_tokens: List[Union[uuid.UUID, str]], + dead_letter_details: Optional[Dict[str, Any]] = None + ) -> Any: message = { MGMT_REQUEST_DISPOSITION_STATUS: settlement, - MGMT_REQUEST_LOCK_TOKENS: types.AMQPArray(lock_tokens), + MGMT_REQUEST_LOCK_TOKENS: self._amqp_transport.AMQP_ARRAY_VALUE(lock_tokens), } self._populate_message_properties(message) @@ -567,10 +551,10 @@ def _settle_message_via_mgmt_link( REQUEST_RESPONSE_UPDATE_DISPOSTION_OPERATION, message, mgmt_handlers.default ) - def _renew_locks(self, *lock_tokens, **kwargs): - # type: (str, Any) -> Any + + def _renew_locks(self, *lock_tokens: str, **kwargs: Any) -> Any: timeout = kwargs.pop("timeout", None) - message = {MGMT_REQUEST_LOCK_TOKENS: types.AMQPArray(lock_tokens)} + message = {MGMT_REQUEST_LOCK_TOKENS: self._amqp_transport.AMQP_ARRAY_VALUE(lock_tokens)} return self._mgmt_request_response_with_retry( REQUEST_RESPONSE_RENEWLOCK_OPERATION, message, @@ -583,8 +567,7 @@ def _close_handler(self): super(ServiceBusReceiver, self)._close_handler() @property - def session(self): - # 
type: () -> ServiceBusSession + def session(self) -> ServiceBusSession: """ Get the ServiceBusSession object linked with the receiver. Session is only available to session-enabled entities, it would return None if called on a non-sessionful receiver. @@ -602,13 +585,13 @@ def session(self): """ return self._session # type: ignore - def close(self): - # type: () -> None + def close(self) -> None: super(ServiceBusReceiver, self).close() self._message_iter = None # pylint: disable=attribute-defined-outside-init - def _get_streaming_message_iter(self, max_wait_time=None): - # type: (Optional[float]) -> Iterator[ServiceBusReceivedMessage] + def _get_streaming_message_iter( + self, max_wait_time: Optional[float] = None + ) -> Iterator[ServiceBusReceivedMessage]: """Receive messages from an iterator indefinitely, or if a max_wait_time is specified, until such a timeout occurs. @@ -676,6 +659,7 @@ def receive_messages( raise ValueError("The max_wait_time must be greater than 0.") if max_message_count is not None and max_message_count <= 0: raise ValueError("The max_message_count must be greater than 0") + start_time = time.time_ns() messages = self._do_retryable_operation( self._receive, max_message_count=max_message_count, @@ -683,7 +667,7 @@ def receive_messages( operation_requires_timeout=True, ) links = get_receive_links(messages) - with receive_trace_context_manager(self, links=links): + with receive_trace_context_manager(self, links=links, start_time=start_time): if ( self._auto_lock_renewer and not self._session @@ -732,16 +716,16 @@ def receive_deferred_messages( if len(sequence_numbers) == 0: return [] # no-op on empty list. 
self._open() - uamqp_receive_mode = ServiceBusToAMQPReceiveModeMap[self._receive_mode] + amqp_receive_mode = self._amqp_transport.ServiceBusToAMQPReceiveModeMap[self._receive_mode] try: - receive_mode = uamqp_receive_mode.value.value + receive_mode = cast(Enum, amqp_receive_mode).value except AttributeError: - receive_mode = int(uamqp_receive_mode.value) + receive_mode = int(amqp_receive_mode) message = { - MGMT_REQUEST_SEQUENCE_NUMBERS: types.AMQPArray( - [types.AMQPLong(s) for s in sequence_numbers] + MGMT_REQUEST_SEQUENCE_NUMBERS: self._amqp_transport.AMQP_ARRAY_VALUE( + [self._amqp_transport.AMQP_LONG_VALUE(s) for s in sequence_numbers] ), - MGMT_REQUEST_RECEIVER_SETTLE_MODE: types.AMQPuInt(receive_mode), + MGMT_REQUEST_RECEIVER_SETTLE_MODE: self._amqp_transport.AMQP_UINT_VALUE(receive_mode), } self._populate_message_properties(message) @@ -750,7 +734,9 @@ def receive_deferred_messages( mgmt_handlers.deferred_message_op, receive_mode=self._receive_mode, receiver=self, + amqp_transport=self._amqp_transport, ) + start_time = time.time_ns() messages = self._mgmt_request_response_with_retry( REQUEST_RESPONSE_RECEIVE_BY_SEQUENCE_NUMBER, message, @@ -759,7 +745,7 @@ def receive_deferred_messages( ) links = get_receive_links(messages) with receive_trace_context_manager( - self, span_name=SPAN_NAME_RECEIVE_DEFERRED, links=links + self, span_name=SPAN_NAME_RECEIVE_DEFERRED, links=links, start_time=start_time ): if ( self._auto_lock_renewer @@ -783,10 +769,8 @@ def peek_messages( Peeked messages are not removed from queue, nor are they locked. They cannot be completed, deferred or dead-lettered. - For more information about message browsing see https://aka.ms/azsdk/servicebus/message-browsing - - :param int max_message_count: The maximum number of messages to try and peek. The actual number of messages - returned may be fewer and are subject to service limits. The default value is 1. + :param int max_message_count: The maximum number of messages to try and peek. 
The default + value is 1. :keyword int sequence_number: A message sequence number from which to start browsing messages. :keyword Optional[float] timeout: The total operation timeout in seconds including all the retries. The value must be greater than 0 if specified. The default value is None, meaning no timeout. @@ -815,17 +799,18 @@ def peek_messages( self._open() message = { - MGMT_REQUEST_FROM_SEQUENCE_NUMBER: types.AMQPLong(sequence_number), + MGMT_REQUEST_FROM_SEQUENCE_NUMBER: self._amqp_transport.AMQP_LONG_VALUE(sequence_number), MGMT_REQUEST_MAX_MESSAGE_COUNT: max_message_count, } self._populate_message_properties(message) - handler = functools.partial(mgmt_handlers.peek_op, receiver=self) + handler = functools.partial(mgmt_handlers.peek_op, receiver=self, amqp_transport=self._amqp_transport) + start_time = time.time_ns() messages = self._mgmt_request_response_with_retry( REQUEST_RESPONSE_PEEK_OPERATION, message, handler, timeout=timeout ) links = get_receive_links(messages) - with receive_trace_context_manager(self, span_name=SPAN_NAME_PEEK, links=links): + with receive_trace_context_manager(self, span_name=SPAN_NAME_PEEK, links=links, start_time=start_time): return messages def complete_message(self, message: ServiceBusReceivedMessage) -> None: diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_servicebus_sender.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_servicebus_sender.py index fdfc15a5b7d98..41a31b40bfbf1 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/_servicebus_sender.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_servicebus_sender.py @@ -9,10 +9,6 @@ import warnings from typing import Any, TYPE_CHECKING, Union, List, Optional, Mapping, cast -import uamqp -from uamqp import SendClient, types -from uamqp.authentication.common import AMQPAuth - from ._base_handler import BaseHandler from ._common import mgmt_handlers from ._common.message import ( @@ -20,15 +16,15 @@ ServiceBusMessageBatch, ) from 
.amqp import AmqpAnnotatedMessage -from .exceptions import ( - OperationTimeoutError, - _ServiceBusErrorPolicy, -) -from ._common.utils import ( - create_authentication, - transform_messages_if_needed, +from ._common.utils import create_authentication, transform_outbound_messages +from ._common.tracing import ( send_trace_context_manager, trace_message, + is_tracing_enabled, + get_span_links_from_batch, + get_span_link_from_message, + SPAN_NAME_SCHEDULE, + TraceAttributes, ) from ._common.constants import ( REQUEST_RESPONSE_CANCEL_SCHEDULED_MESSAGE_OPERATION, @@ -39,7 +35,7 @@ MGMT_REQUEST_MESSAGES, MGMT_REQUEST_MESSAGE_ID, MGMT_REQUEST_PARTITION_KEY, - SPAN_NAME_SCHEDULE, + MAX_MESSAGE_LENGTH_BYTES, ) if TYPE_CHECKING: @@ -48,6 +44,16 @@ AzureSasCredential, AzureNamedKeyCredential, ) + try: + # pylint:disable=unused-import + from uamqp import SendClient as uamqp_SendClientSync + from uamqp.authentication import JWTTokenAuth as uamqp_JWTTokenAuth + except ImportError: + pass + + from ._transport._base import AmqpTransport + from ._pyamqp.authentication import JWTTokenAuth as pyamqp_JWTTokenAuth + from ._pyamqp.client import SendClient as pyamqp_SendClientSync MessageTypes = Union[ Mapping[str, Any], ServiceBusMessage, @@ -66,47 +72,42 @@ class SenderMixin(object): def _create_attribute(self, **kwargs): - self._auth_uri = "sb://{}/{}".format( - self.fully_qualified_namespace, self._entity_name - ) - self._entity_uri = "amqps://{}/{}".format( - self.fully_qualified_namespace, self._entity_name - ) - self._error_policy = _ServiceBusErrorPolicy( - max_retries=self._config.retry_total - ) - self._name = kwargs.get("client_identifier","SBSender-{}".format(uuid.uuid4())) + self._auth_uri = f"sb://{self.fully_qualified_namespace}/{self._entity_name}" + self._entity_uri = f"amqps://{self.fully_qualified_namespace}/{self._entity_name}" + # TODO: What's the retry overlap between servicebus and pyamqp? 
+ self._error_policy = self._amqp_transport.create_retry_policy(self._config) + self._name = kwargs.get("client_identifier",f"SBSender-{uuid.uuid4()}") self._max_message_size_on_link = 0 self.entity_name = self._entity_name - def _set_msg_timeout(self, timeout=None, last_exception=None): - # pylint: disable=protected-access - if not timeout: - self._handler._msg_timeout = 0 - return - if timeout <= 0.0: - if last_exception: - error = last_exception - else: - error = OperationTimeoutError(message="Send operation timed out") - _LOGGER.info("%r send operation timed out. (%r)", self._name, error) - raise error - self._handler._msg_timeout = timeout * 1000 # type: ignore - @classmethod - def _build_schedule_request(cls, schedule_time_utc, send_span, *messages): + def _build_schedule_request(cls, schedule_time_utc, amqp_transport, tracing_attributes, *messages): request_body = {MGMT_REQUEST_MESSAGES: []} + trace_links = [] for message in messages: if not isinstance(message, ServiceBusMessage): raise ValueError( - "Scheduling batch messages only supports iterables containing " - "ServiceBusMessage Objects. Received instead: {}".format( - message.__class__.__name__ - ) + f"Scheduling batch messages only supports iterables containing " + f"ServiceBusMessage Objects. 
Received instead: {message.__class__.__name__}" ) message.scheduled_enqueue_time_utc = schedule_time_utc - message = transform_messages_if_needed(message, ServiceBusMessage) - trace_message(message, send_span) + message = transform_outbound_messages( + message, + ServiceBusMessage, + to_outgoing_amqp_message=amqp_transport.to_outgoing_amqp_message + ) + # pylint: disable=protected-access + message._message = trace_message( + message._message, + amqp_transport=amqp_transport, + additional_attributes=tracing_attributes + ) + + if is_tracing_enabled(): + link = get_span_link_from_message(message._message) + if link: + trace_links.append(link) + message_data = {} message_data[MGMT_REQUEST_MESSAGE_ID] = message.message_id if message.session_id: @@ -114,10 +115,10 @@ def _build_schedule_request(cls, schedule_time_utc, send_span, *messages): if message.partition_key: message_data[MGMT_REQUEST_PARTITION_KEY] = message.partition_key message_data[MGMT_REQUEST_MESSAGE] = bytearray( - message.message.encode_message() + amqp_transport.encode_message(message) ) request_body[MGMT_REQUEST_MESSAGES].append(message_data) - return request_body + return request_body, trace_links class ServiceBusSender(BaseHandler, SenderMixin): @@ -165,6 +166,7 @@ def __init__( topic_name: Optional[str] = None, **kwargs: Any ) -> None: + self._amqp_transport: "AmqpTransport" if kwargs.get("entity_name"): super(ServiceBusSender, self).__init__( fully_qualified_namespace=fully_qualified_namespace, @@ -193,10 +195,10 @@ def __init__( self._max_message_size_on_link = 0 self._create_attribute(**kwargs) self._connection = kwargs.get("connection") + self._handler: Union["pyamqp_SendClientSync", "uamqp_SendClientSync"] @classmethod - def _from_connection_string(cls, conn_str, **kwargs): - # type: (str, Any) -> ServiceBusSender + def _from_connection_string(cls, conn_str: str, **kwargs: Any) -> "ServiceBusSender": """Create a ServiceBusSender from a connection string. 
:param conn_str: The connection string of a Service Bus. @@ -232,17 +234,15 @@ def _from_connection_string(cls, conn_str, **kwargs): constructor_args = cls._convert_connection_string_to_kwargs(conn_str, **kwargs) return cls(**constructor_args) - def _create_handler(self, auth): - # type: (AMQPAuth) -> None - self._handler = SendClient( - self._entity_uri, + def _create_handler(self, auth: Union["uamqp_JWTTokenAuth", "pyamqp_JWTTokenAuth"]) -> None: + + self._handler = self._amqp_transport.create_send_client( + config=self._config, + target=self._entity_uri, auth=auth, - debug=self._config.logging_enable, properties=self._properties, - error_policy=self._error_policy, + retry_policy=self._error_policy, client_name=self._name, - keep_alive_interval=self._config.keep_alive, - encoding=self._config.encoding, ) def _open(self): @@ -260,8 +260,8 @@ def _open(self): time.sleep(0.05) self._running = True self._max_message_size_on_link = ( - self._handler.message_handler._link.peer_max_message_size - or uamqp.constants.MAX_MESSAGE_LENGTH_BYTES + self._amqp_transport.get_remote_max_message_size(self._handler) + or MAX_MESSAGE_LENGTH_BYTES ) except: self._close_handler() @@ -271,15 +271,9 @@ def _send( self, message: Union[ServiceBusMessage, ServiceBusMessageBatch], timeout: Optional[float] = None, - last_exception: Optional[Exception] = None + last_exception: Optional[Exception] = None # pylint: disable=unused-argument ) -> None: - self._open() - default_timeout = self._handler._msg_timeout # pylint: disable=protected-access - try: - self._set_msg_timeout(timeout, last_exception) - self._handler.send_message(message.message) - finally: # reset the timeout of the handler back to the default value - self._set_msg_timeout(default_timeout, None) + self._amqp_transport.send_messages(self, message, _LOGGER, timeout=timeout, last_exception=last_exception) def schedule_messages( self, @@ -315,23 +309,34 @@ def schedule_messages( # pylint: disable=protected-access self._check_live() 
- obj_messages = transform_messages_if_needed(messages, ServiceBusMessage) + obj_messages = transform_outbound_messages( + messages, ServiceBusMessage, to_outgoing_amqp_message=self._amqp_transport.to_outgoing_amqp_message + ) if timeout is not None and timeout <= 0: raise ValueError("The timeout must be greater than 0.") - with send_trace_context_manager(span_name=SPAN_NAME_SCHEDULE) as send_span: - if isinstance(obj_messages, ServiceBusMessage): - request_body = self._build_schedule_request( - schedule_time_utc, send_span, obj_messages - ) - else: - if len(obj_messages) == 0: - return [] # No-op on empty list. - request_body = self._build_schedule_request( - schedule_time_utc, send_span, *obj_messages - ) - if send_span: - self._add_span_request_attributes(send_span) + tracing_attributes = { + TraceAttributes.TRACE_NET_PEER_NAME_ATTRIBUTE: self.fully_qualified_namespace, + TraceAttributes.TRACE_MESSAGING_DESTINATION_ATTRIBUTE: self.entity_name, + } + if isinstance(obj_messages, ServiceBusMessage): + request_body, trace_links = self._build_schedule_request( + schedule_time_utc, + self._amqp_transport, + tracing_attributes, + obj_messages + ) + else: + if len(obj_messages) == 0: + return [] # No-op on empty list. 
+ request_body, trace_links = self._build_schedule_request( + schedule_time_utc, + self._amqp_transport, + tracing_attributes, + *obj_messages + ) + + with send_trace_context_manager(self, span_name=SPAN_NAME_SCHEDULE, links=trace_links): return self._mgmt_request_response_with_retry( REQUEST_RESPONSE_SCHEDULE_MESSAGE_OPERATION, request_body, @@ -372,12 +377,12 @@ def cancel_scheduled_messages( if timeout is not None and timeout <= 0: raise ValueError("The timeout must be greater than 0.") if isinstance(sequence_numbers, int): - numbers = [types.AMQPLong(sequence_numbers)] + numbers = [self._amqp_transport.AMQP_LONG_VALUE(sequence_numbers)] else: - numbers = [types.AMQPLong(s) for s in sequence_numbers] + numbers = [self._amqp_transport.AMQP_LONG_VALUE(s) for s in sequence_numbers] if len(numbers) == 0: return None # no-op on empty list. - request_body = {MGMT_REQUEST_SEQUENCE_NUMBERS: types.AMQPArray(numbers)} + request_body = {MGMT_REQUEST_SEQUENCE_NUMBERS: self._amqp_transport.AMQP_ARRAY_VALUE(numbers)} return self._mgmt_request_response_with_retry( REQUEST_RESPONSE_CANCEL_SCHEDULED_MESSAGE_OPERATION, request_body, @@ -428,29 +433,52 @@ def send_messages( if timeout is not None and timeout <= 0: raise ValueError("The timeout must be greater than 0.") - with send_trace_context_manager() as send_span: - if isinstance(message, ServiceBusMessageBatch): - obj_message = message # type: MessageObjTypes + try: # Short circuit noop if an empty list or batch is provided. + if len(cast(Union[List, ServiceBusMessageBatch], message)) == 0: # pylint: disable=len-as-condition + return + except TypeError: # continue if ServiceBusMessage + pass + + obj_message: Union[ServiceBusMessage, ServiceBusMessageBatch] + + if isinstance(message, ServiceBusMessageBatch): + # If AmqpTransports are not the same, create batch with correct BatchMessage. 
+ if self._amqp_transport.KIND != message._amqp_transport.KIND: # pylint: disable=protected-access + # pylint: disable=protected-access + batch = self.create_message_batch() + batch._from_list(message._messages) # type: ignore + obj_message = batch else: - obj_message = transform_messages_if_needed( # type: ignore - message, ServiceBusMessage + obj_message = message + else: + obj_message = transform_outbound_messages( # type: ignore + message, ServiceBusMessage, self._amqp_transport.to_outgoing_amqp_message + ) + try: + batch = self.create_message_batch() + batch._from_list(obj_message) # type: ignore # pylint: disable=protected-access + obj_message = batch + except TypeError: # Message was not a list or generator. Do needed tracing. + # pylint: disable=protected-access + obj_message._message = trace_message( + obj_message._message, + amqp_transport=self._amqp_transport, + additional_attributes={ + TraceAttributes.TRACE_NET_PEER_NAME_ATTRIBUTE: self.fully_qualified_namespace, + TraceAttributes.TRACE_MESSAGING_DESTINATION_ATTRIBUTE: self.entity_name, + } ) - try: - batch = self.create_message_batch() - batch._from_list(obj_message, send_span) # type: ignore # pylint: disable=protected-access - obj_message = batch - except TypeError: # Message was not a list or generator. Do needed tracing. - trace_message(cast(ServiceBusMessage, obj_message), send_span) - - if ( - isinstance(obj_message, ServiceBusMessageBatch) - and len(obj_message) == 0 - ): # pylint: disable=len-as-condition - return # Short circuit noop if an empty list or batch is provided. 
- - if send_span: - self._add_span_request_attributes(send_span) + trace_links = [] + if is_tracing_enabled(): + if isinstance(obj_message, ServiceBusMessageBatch): + trace_links = get_span_links_from_batch(obj_message) + else: + link = get_span_link_from_message(obj_message._message) # pylint: disable=protected-access + if link: + trace_links.append(link) + + with send_trace_context_manager(self, links=trace_links): self._do_retryable_operation( self._send, message=obj_message, @@ -486,13 +514,17 @@ def create_message_batch( if max_size_in_bytes and max_size_in_bytes > self._max_message_size_on_link: raise ValueError( - "Max message size: {} is too large, acceptable max batch size is: {} bytes.".format( - max_size_in_bytes, self._max_message_size_on_link - ) + f"Max message size: {max_size_in_bytes} is too large, " + f"acceptable max batch size is: {self._max_message_size_on_link} bytes." ) return ServiceBusMessageBatch( - max_size_in_bytes=(max_size_in_bytes or self._max_message_size_on_link) + max_size_in_bytes=(max_size_in_bytes or self._max_message_size_on_link), + amqp_transport=self._amqp_transport, + tracing_attributes = { + TraceAttributes.TRACE_NET_PEER_NAME_ATTRIBUTE: self.fully_qualified_namespace, + TraceAttributes.TRACE_MESSAGING_DESTINATION_ATTRIBUTE: self.entity_name, + } ) @property diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_servicebus_session.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_servicebus_session.py index bc30411014da4..f095c85f1ab1f 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/_servicebus_session.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_servicebus_session.py @@ -29,19 +29,21 @@ class BaseSession(object): - def __init__(self, session_id, receiver): - # type: (str, Union[ServiceBusReceiver, ServiceBusReceiverAsync]) -> None + def __init__( + self, + session_id: str, + receiver: Union["ServiceBusReceiver", "ServiceBusReceiverAsync"] + ) -> None: self._session_id = session_id 
self._receiver = receiver self._encoding = "UTF-8" self._session_start = None - self._locked_until_utc = None # type: Optional[datetime.datetime] + self._locked_until_utc: Optional[datetime.datetime] = None self._lock_lost = False self.auto_renew_error = None @property - def _lock_expired(self): - # type: () -> bool + def _lock_expired(self) -> bool: """Whether the receivers lock on a particular session has expired. :rtype: bool @@ -49,8 +51,7 @@ def _lock_expired(self): return bool(self._locked_until_utc and self._locked_until_utc <= utc_now()) @property - def session_id(self): - # type: () -> str + def session_id(self) -> str: """ Session id of the current session. @@ -59,8 +60,7 @@ def session_id(self): return self._session_id @property - def locked_until_utc(self): - # type: () -> Optional[datetime.datetime] + def locked_until_utc(self) -> Optional[datetime.datetime]: """The time at which this session's lock will expire. :rtype: datetime.datetime diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_transport/__init__.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_transport/__init__.py new file mode 100644 index 0000000000000..34913fb394d7a --- /dev/null +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_transport/__init__.py @@ -0,0 +1,4 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. 
+# -------------------------------------------------------------------------------------------- diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_transport/_base.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_transport/_base.py new file mode 100644 index 0000000000000..1b9025f4edd41 --- /dev/null +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_transport/_base.py @@ -0,0 +1,333 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- +from __future__ import annotations +from typing import Union, TYPE_CHECKING, Dict, Any, Callable +from abc import ABC, abstractmethod + +if TYPE_CHECKING: + try: + from uamqp import types as uamqp_types + except ImportError: + pass + +class AmqpTransport(ABC): # pylint: disable=too-many-public-methods + """ + Abstract class that defines a set of common methods needed by sender and receiver. + """ + KIND: str + + # define constants + MAX_FRAME_SIZE_BYTES: int + MAX_MESSAGE_LENGTH_BYTES: int + TIMEOUT_FACTOR: int + TRANSPORT_IDENTIFIER: str + + ServiceBusToAMQPReceiveModeMap: Dict[str, Any] + + # define symbols + PRODUCT_SYMBOL: Union[uamqp_types.AMQPSymbol, str, bytes] + VERSION_SYMBOL: Union[uamqp_types.AMQPSymbol, str, bytes] + FRAMEWORK_SYMBOL: Union[uamqp_types.AMQPSymbol, str, bytes] + PLATFORM_SYMBOL: Union[uamqp_types.AMQPSymbol, str, bytes] + USER_AGENT_SYMBOL: Union[uamqp_types.AMQPSymbol, str, bytes] + PROP_PARTITION_KEY_AMQP_SYMBOL: Union[uamqp_types.AMQPSymbol, str, bytes] + AMQP_LONG_VALUE: Callable + AMQP_ARRAY_VALUE: Callable + AMQP_UINT_VALUE: Callable + + @staticmethod + @abstractmethod + def build_message(**kwargs): + """ + Creates a uamqp.Message or pyamqp.Message with given arguments. 
+        :rtype: uamqp.Message or pyamqp.Message
+        """
+
+    @staticmethod
+    @abstractmethod
+    def build_batch_message(data):
+        """
+        Creates a uamqp.BatchMessage or pyamqp.BatchMessage with given arguments.
+        :rtype: uamqp.BatchMessage or pyamqp.BatchMessage
+        """
+
+    @staticmethod
+    @abstractmethod
+    def to_outgoing_amqp_message(annotated_message):
+        """
+        Converts an AmqpAnnotatedMessage into an Amqp Message.
+        :param AmqpAnnotatedMessage annotated_message: AmqpAnnotatedMessage to convert.
+        :rtype: uamqp.Message or pyamqp.Message
+        """
+
+    @staticmethod
+    @abstractmethod
+    def encode_message(message):
+        """
+        Encodes the message's underlying uamqp/pyamqp.Message.
+        :param ServiceBusMessage message: Message.
+        :rtype: bytes
+        """
+
+    @staticmethod
+    @abstractmethod
+    def update_message_app_properties(message, key, value):
+        """
+        Adds the given key/value to the application properties of the message.
+        :param uamqp.Message or pyamqp.Message message: Message.
+        :param str key: Key to set in application properties.
+        :param str value: Value to set for key in application properties.
+        :rtype: uamqp.Message or pyamqp.Message
+        """
+
+    @staticmethod
+    @abstractmethod
+    def get_message_encoded_size(message):
+        """
+        Gets the message encoded size given an underlying Message.
+        :param uamqp.Message or pyamqp.Message message: Message to get encoded size of.
+        :rtype: int
+        """
+
+    @staticmethod
+    @abstractmethod
+    def get_remote_max_message_size(handler):
+        """
+        Returns max peer message size.
+        :param AMQPClient handler: Client to get remote max message size on link from.
+        :rtype: int
+        """
+
+    @staticmethod
+    @abstractmethod
+    def get_handler_link_name(handler):
+        """
+        Returns link name.
+        :param AMQPClient handler: Client to get name of link from.
+        :rtype: str
+        """
+
+    @staticmethod
+    @abstractmethod
+    def create_retry_policy(config, *, is_session=False):
+        """
+        Creates the error retry policy.
+ :param ~azure.servicebus._configuration.Configuration config: Configuration. + :keyword bool is_session: Is session enabled. + """ + + @staticmethod + @abstractmethod + def create_connection(host, auth, network_trace, **kwargs): + """ + Creates and returns the uamqp/pyamqp Connection object. + :param str host: The hostname used by uamqp/pyamqp. + :param JWTTokenAuth auth: The auth used by uamqp/pyamqp. + :param bool network_trace: Debug setting. + """ + + @staticmethod + @abstractmethod + def close_connection(connection): + """ + Closes existing connection. + :param connection: uamqp or pyamqp Connection. + """ + + @staticmethod + @abstractmethod + def create_send_client(config, **kwargs): + """ + Creates and returns the uamqp SendClient. + :param ~azure.servicebus._common._configuration.Configuration config: + The configuration. + :keyword str target: Required. The target. + :keyword JWTTokenAuth auth: Required. + :keyword int idle_timeout: Required. + :keyword network_trace: Required. + :keyword retry_policy: Required. + :keyword keep_alive_interval: Required. + :keyword str client_name: Required. + :keyword dict link_properties: Required. + :keyword properties: Required. + """ + + @staticmethod + @abstractmethod + def send_messages( + sender, message, logger, timeout, last_exception + ): + """ + Handles sending of service bus messages. + :param ~azure.servicebus.ServiceBusSender sender: The sender with handler + to send messages. + :param message: ServiceBusMessage with uamqp.Message to be sent. + :paramtype message: ~azure.servicebus.ServiceBusMessage or ~azure.servicebus.ServiceBusMessageBatch + :param int timeout: Timeout time. + :param last_exception: Exception to raise if message timed out. Only used by uamqp transport. + :param logger: Logger. + """ + + @staticmethod + @abstractmethod + def add_batch(sb_message_batch, outgoing_sb_message): + """ + Add ServiceBusMessage to the data body of the BatchMessage. 
+        :param sb_message_batch: ServiceBusMessageBatch to add data to.
+        :param outgoing_sb_message: Transformed ServiceBusMessage for sending.
+        :rtype: None
+        """
+
+    @staticmethod
+    @abstractmethod
+    def create_source(source, session_filter):
+        """
+        Creates and returns the Source.
+
+        :param Source source: Required.
+        :param str or None session_filter: Required.
+        """
+
+    @staticmethod
+    @abstractmethod
+    def create_receive_client(receiver, **kwargs):
+        """
+        Creates and returns the receive client.
+        :param ServiceBusReceiver receiver: The receiver whose configuration is used.
+
+        :keyword str source: Required. The source.
+        :keyword str offset: Required.
+        :keyword str offset_inclusive: Required.
+        :keyword JWTTokenAuth auth: Required.
+        :keyword int idle_timeout: Required.
+        :keyword network_trace: Required.
+        :keyword retry_policy: Required.
+        :keyword str client_name: Required.
+        :keyword dict link_properties: Required.
+        :keyword properties: Required.
+        :keyword link_credit: Required. The prefetch.
+        :keyword keep_alive_interval: Required.
+        :keyword desired_capabilities: Required.
+        :keyword timeout: Required.
+        """
+
+    @staticmethod
+    @abstractmethod
+    def iter_contextual_wrapper(
+        receiver, max_wait_time=None
+    ):
+        """The purpose of this wrapper is to allow both state restoration (for multiple concurrent iteration)
+        and per-iter argument passing that requires the former."""
+
+    @staticmethod
+    @abstractmethod
+    def iter_next(
+        receiver, wait_time=None
+    ):
+        """
+        Used to iterate through received messages.
+        """
+
+    @staticmethod
+    @abstractmethod
+    def build_received_message(receiver, message_type, received):
+        """
+        Build ServiceBusReceivedMessage.
+        """
+
+    @staticmethod
+    @abstractmethod
+    def get_current_time(handler):
+        """
+        Gets the current time.
+        """
+
+    @staticmethod
+    @abstractmethod
+    def reset_link_credit(
+        handler, link_credit
+    ):
+        """
+        Resets the link credit on the link.
+ """ + + @staticmethod + @abstractmethod + def settle_message_via_receiver_link( + handler, + message, + settle_operation, + dead_letter_reason=None, + dead_letter_error_description=None, + ) -> None: + """ + Settles message. + """ + + @staticmethod + @abstractmethod + def parse_received_message(message, message_type, **kwargs): + """ + Parses peek/deferred op messages into ServiceBusReceivedMessage. + :param Message message: Message to parse. + :param ServiceBusReceivedMessage message_type: Parse messages to return. + :keyword ServiceBusReceiver receiver: Required. + :keyword bool is_peeked_message: Optional. For peeked messages. + :keyword bool is_deferred_message: Optional. For deferred messages. + :keyword ServiceBusReceiveMode receive_mode: Optional. + """ + + @staticmethod + @abstractmethod + def get_message_value(message): + """Get body of type value from message.""" + + @staticmethod + @abstractmethod + def create_token_auth( + auth_uri, get_token, token_type, config, **kwargs + ): + """ + Creates the JWTTokenAuth. + :param str auth_uri: The auth uri to pass to JWTTokenAuth. + :param get_token: The callback function used for getting and refreshing + tokens. It should return a valid jwt token each time it is called. + :param bytes token_type: Token type. + :param ~azure.servicebus._configuration.Configuration config: EH config. + + :keyword bool update_token: Whether to update token. If not updating token, + then pass 300 to refresh_window. Only used by uamqp. + """ + + @staticmethod + @abstractmethod + def create_mgmt_msg( + message, application_properties, config, reply_to, **kwargs + ): + """ + :param message: The message to send in the management request. + :paramtype message: Any + :param Dict[bytes, str] application_properties: App props. + :param ~azure.servicebus._common._configuration.Configuration config: Configuration. + :param str reply_to: Reply to. 
+ :rtype: uamqp.Message or pyamqp.Message + """ + + @staticmethod + @abstractmethod + def mgmt_client_request( + mgmt_client, mgmt_msg, *, operation, operation_type, node, timeout, callback + ): + """ + Send mgmt request and return result of callback. + :param AMQPClient mgmt_client: Client to send request with. + :param Message mgmt_msg: Message. + :keyword bytes operation: Operation. + :keyword bytes operation_type: Op type. + :keyword bytes node: Mgmt target. + :keyword int timeout: Timeout. + :keyword Callable callback: Callback to process request response. + """ diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_transport/_pyamqp_transport.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_transport/_pyamqp_transport.py new file mode 100644 index 0000000000000..9743a250d5103 --- /dev/null +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_transport/_pyamqp_transport.py @@ -0,0 +1,954 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. 
+# -------------------------------------------------------------------------------------------- + +import functools +import time +import datetime +from datetime import timezone +from typing import Optional, Tuple, cast, List, TYPE_CHECKING, Any, Callable, Dict, Union, Iterator, Type + +from .._pyamqp import ( + utils, + SendClient, + constants, + ReceiveClient, + __version__, +) +from .._pyamqp.error import ( + ErrorCondition, + AMQPException, + AMQPError, + RetryPolicy, + AMQPConnectionError, + AuthenticationException, + MessageException, +) +from .._pyamqp.utils import amqp_long_value, amqp_array_value, amqp_string_value, amqp_uint_value +from .._pyamqp._encode import encode_payload +from .._pyamqp._decode import decode_payload +from .._pyamqp.message import Message, BatchMessage, Header, Properties +from .._pyamqp.authentication import JWTTokenAuth +from .._pyamqp.endpoints import Source +from .._pyamqp._connection import Connection, _CLOSING_STATES + +from ._base import AmqpTransport +from .._common.utils import utc_from_timestamp, utc_now +from .._common.tracing import get_receive_links, receive_trace_context_manager +from .._common.constants import ( + PYAMQP_LIBRARY, + DATETIMEOFFSET_EPOCH, + RECEIVER_LINK_DEAD_LETTER_ERROR_DESCRIPTION, + RECEIVER_LINK_DEAD_LETTER_REASON, + DEADLETTERNAME, + MAX_ABSOLUTE_EXPIRY_TIME, + MAX_DURATION_VALUE, + MAX_MESSAGE_LENGTH_BYTES, + MESSAGE_COMPLETE, + MESSAGE_ABANDON, + MESSAGE_DEFER, + MESSAGE_DEAD_LETTER, + SESSION_FILTER, + SESSION_LOCKED_UNTIL, + _X_OPT_ENQUEUED_TIME, + _X_OPT_LOCKED_UNTIL, + ERROR_CODE_SESSION_LOCK_LOST, + ERROR_CODE_MESSAGE_LOCK_LOST, + ERROR_CODE_MESSAGE_NOT_FOUND, + ERROR_CODE_TIMEOUT, + ERROR_CODE_AUTH_FAILED, + ERROR_CODE_SESSION_CANNOT_BE_LOCKED, + ERROR_CODE_SERVER_BUSY, + ERROR_CODE_ARGUMENT_ERROR, + ERROR_CODE_OUT_OF_RANGE, + ERROR_CODE_ENTITY_DISABLED, + ERROR_CODE_ENTITY_ALREADY_EXISTS, + ERROR_CODE_PRECONDITION_FAILED, + ServiceBusReceiveMode, +) + +from ..exceptions import ( + 
MessageSizeExceededError, + ServiceBusQuotaExceededError, + ServiceBusAuthorizationError, + ServiceBusError, + ServiceBusConnectionError, + ServiceBusAuthenticationError, + ServiceBusCommunicationError, + MessageLockLostError, + MessageNotFoundError, + MessagingEntityDisabledError, + MessagingEntityNotFoundError, + MessagingEntityAlreadyExistsError, + ServiceBusServerBusyError, + SessionCannotBeLockedError, + SessionLockLostError, + OperationTimeoutError +) + +if TYPE_CHECKING: + from logging import Logger + from ..amqp import AmqpAnnotatedMessage, AmqpMessageHeader, AmqpMessageProperties + from .._servicebus_receiver import ServiceBusReceiver + from .._servicebus_sender import ServiceBusSender + from .._common.message import ServiceBusReceivedMessage, ServiceBusMessage, ServiceBusMessageBatch + from .._common._configuration import Configuration + from .._pyamqp.performatives import AttachFrame, TransferFrame + from .._pyamqp.client import AMQPClient + + +class _ServiceBusErrorPolicy(RetryPolicy): + + no_retry = RetryPolicy.no_retry + cast(List[ErrorCondition], [ + ERROR_CODE_SESSION_LOCK_LOST, + ERROR_CODE_MESSAGE_LOCK_LOST, + ERROR_CODE_OUT_OF_RANGE, + ERROR_CODE_ARGUMENT_ERROR, + ERROR_CODE_PRECONDITION_FAILED, + ]) + + def __init__(self, is_session=False, **kwargs): + self._is_session = is_session + custom_condition_backoff = { + b"com.microsoft:server-busy": 4, + b"com.microsoft:timeout": 2, + b"com.microsoft:container-close": 4 + } + super(_ServiceBusErrorPolicy, self).__init__( + custom_condition_backoff=custom_condition_backoff, + **kwargs + ) + + def is_retryable(self, error): + if self._is_session: + return False + return super().is_retryable(error) + +_LONG_ANNOTATIONS = ( + _X_OPT_ENQUEUED_TIME, + _X_OPT_LOCKED_UNTIL +) + +_ERROR_CODE_TO_ERROR_MAPPING = { + ErrorCondition.LinkMessageSizeExceeded: MessageSizeExceededError, + ErrorCondition.ResourceLimitExceeded: ServiceBusQuotaExceededError, + ErrorCondition.UnauthorizedAccess: 
ServiceBusAuthorizationError,
+    ErrorCondition.NotImplemented: ServiceBusError,
+    ErrorCondition.NotAllowed: ServiceBusError,
+    ErrorCondition.LinkDetachForced: ServiceBusConnectionError,
+    ERROR_CODE_MESSAGE_LOCK_LOST: MessageLockLostError,
+    ERROR_CODE_MESSAGE_NOT_FOUND: MessageNotFoundError,
+    ERROR_CODE_AUTH_FAILED: ServiceBusAuthorizationError,
+    ERROR_CODE_ENTITY_DISABLED: MessagingEntityDisabledError,
+    ERROR_CODE_ENTITY_ALREADY_EXISTS: MessagingEntityAlreadyExistsError,
+    ERROR_CODE_SERVER_BUSY: ServiceBusServerBusyError,
+    ERROR_CODE_SESSION_CANNOT_BE_LOCKED: SessionCannotBeLockedError,
+    ERROR_CODE_SESSION_LOCK_LOST: SessionLockLostError,
+    ERROR_CODE_ARGUMENT_ERROR: ServiceBusError,
+    ERROR_CODE_OUT_OF_RANGE: ServiceBusError,
+    ERROR_CODE_TIMEOUT: OperationTimeoutError,
+}
+
+class PyamqpTransport(AmqpTransport):  # pylint: disable=too-many-public-methods
+    """
+    Class which defines pyamqp-based methods used by the sender and receiver.
+    """
+
+    KIND = "pyamqp"
+
+    # define constants
+    MAX_FRAME_SIZE_BYTES = constants.MAX_FRAME_SIZE_BYTES
+    MAX_MESSAGE_LENGTH_BYTES = MAX_MESSAGE_LENGTH_BYTES
+    TIMEOUT_FACTOR = 1
+    CONNECTION_CLOSING_STATES: Tuple = _CLOSING_STATES
+    TRANSPORT_IDENTIFIER = f"{PYAMQP_LIBRARY}/{__version__}"
+
+    # To enable extensible string enums for the public facing parameter, and translate to the "real" pyamqp constants.
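The `_ERROR_CODE_TO_ERROR_MAPPING` table above turns AMQP error conditions into SDK exception types. A trimmed-down sketch of that lookup, assuming (as the full code suggests but this excerpt does not show) that unknown conditions fall back to the generic `ServiceBusError`:

```python
class ServiceBusError(Exception): ...
class SessionLockLostError(ServiceBusError): ...
class OperationTimeoutError(ServiceBusError): ...

# Trimmed-down version of _ERROR_CODE_TO_ERROR_MAPPING:
# AMQP error condition (bytes) -> SDK exception type.
ERROR_CODE_TO_ERROR = {
    b"com.microsoft:session-lock-lost": SessionLockLostError,
    b"com.microsoft:timeout": OperationTimeoutError,
}

def map_error(condition: bytes) -> type:
    # Unknown conditions fall back to the generic error (assumed fallback).
    return ERROR_CODE_TO_ERROR.get(condition, ServiceBusError)

print(map_error(b"com.microsoft:timeout").__name__)  # OperationTimeoutError
print(map_error(b"amqp:unknown").__name__)           # ServiceBusError
```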
+ ServiceBusToAMQPReceiveModeMap = { + ServiceBusReceiveMode.PEEK_LOCK: constants.ReceiverSettleMode.Second, + ServiceBusReceiveMode.RECEIVE_AND_DELETE: constants.ReceiverSettleMode.First, + } + + # define symbols + PRODUCT_SYMBOL = "product" + VERSION_SYMBOL = "version" + FRAMEWORK_SYMBOL = "framework" + PLATFORM_SYMBOL = "platform" + USER_AGENT_SYMBOL = "user-agent" + #ERROR_CONDITIONS = [condition.value for condition in ErrorCondition] + + # amqp value types + AMQP_LONG_VALUE: Callable = amqp_long_value + AMQP_ARRAY_VALUE: Callable = amqp_array_value + AMQP_UINT_VALUE: Callable = amqp_uint_value + + # errors + TIMEOUT_ERROR = TimeoutError + + @staticmethod + def build_message(**kwargs: Any) -> "Message": + """ + Creates a pyamqp.Message with given arguments. + :rtype: pyamqp.Message + """ + return Message(**kwargs) + + @staticmethod + def build_batch_message(data: List) -> List[List]: + """ + Creates a List representing a pyamqp.BatchMessage with given arguments. + :rtype: List[List] + """ + message = cast(List, [None] * 9) + message[5] = data + return message + + @staticmethod + def get_message_delivery_tag( + _, frame: "TransferFrame" + ) -> str: # pylint: disable=unused-argument + """ + Gets delivery tag of a Message. + :param message: Message to get delivery_tag from for uamqp.Message. + :param frame: Frame to get delivery_tag from for pyamqp.Message. + :rtype: str + """ + return frame[2] if frame else None + + @staticmethod + def get_message_delivery_id( + _, frame: "TransferFrame" + ) -> str: # pylint: disable=unused-argument + """ + Gets delivery id of a Message. + :param message: Message to get delivery_id from for uamqp.Message. + :param frame: Message to get delivery_id from for pyamqp.Message. + :rtype: str + """ + return frame[1] if frame else None + + @staticmethod + def to_outgoing_amqp_message(annotated_message: "AmqpAnnotatedMessage") -> "Message": + """ + Converts an AmqpAnnotatedMessage into an Amqp Message. 
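`build_batch_message` above leans on pyamqp's positional message encoding: a message is modeled as a 9-field sequence, and per the code, index 5 carries the data body, which is all a batch envelope needs. A self-contained sketch of that layout:

```python
from typing import Any, List

# Position of the data body in pyamqp's 9-field message sequence,
# per the build_batch_message implementation above.
DATA_SLOT = 5

def build_batch_message(data: List[Any]) -> List[Any]:
    message: List[Any] = [None] * 9
    message[DATA_SLOT] = data
    return message

batch = build_batch_message([b"msg-1", b"msg-2"])
print(batch[DATA_SLOT])   # [b'msg-1', b'msg-2']
print(len(batch))         # 9
```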
+ :param AmqpAnnotatedMessage annotated_message: AmqpAnnotatedMessage to convert. + :rtype: pyamqp.Message + """ + message_header = None + ttl_set = False + header_vals = annotated_message.header.values() if annotated_message.header else None + # If header and non-None header values, create outgoing header. + if header_vals and header_vals.count(None) != len(header_vals): + annotated_message.header = cast("AmqpMessageHeader", annotated_message.header) + message_header = Header( + delivery_count=annotated_message.header.delivery_count, + ttl=annotated_message.header.time_to_live, + first_acquirer=annotated_message.header.first_acquirer, + durable=annotated_message.header.durable, + priority=annotated_message.header.priority, + ) + if annotated_message.header.time_to_live and annotated_message.header.time_to_live != MAX_DURATION_VALUE: + ttl_set = True + creation_time_from_ttl = int( + time.mktime(datetime.datetime.now(timezone.utc).timetuple()) * 1000 # TODO: should this be * 1? + ) + absolute_expiry_time_from_ttl = int(min( + MAX_ABSOLUTE_EXPIRY_TIME, + creation_time_from_ttl + annotated_message.header.time_to_live + )) + + message_properties = None + properties_vals = annotated_message.properties.values() if annotated_message.properties else None + # If properties and non-None properties values, create outgoing properties. 
+ if properties_vals and properties_vals.count(None) != len(properties_vals): + annotated_message.properties = cast("AmqpMessageProperties", annotated_message.properties) + creation_time = None + absolute_expiry_time = None + if ttl_set: + creation_time = creation_time_from_ttl + absolute_expiry_time = absolute_expiry_time_from_ttl + else: + if annotated_message.properties.creation_time: + creation_time = int(annotated_message.properties.creation_time) + if annotated_message.properties.absolute_expiry_time: + absolute_expiry_time = int(annotated_message.properties.absolute_expiry_time) + + message_properties = Properties( + message_id=annotated_message.properties.message_id, + user_id=annotated_message.properties.user_id, + to=annotated_message.properties.to, + subject=annotated_message.properties.subject, + reply_to=annotated_message.properties.reply_to, + correlation_id=annotated_message.properties.correlation_id, + content_type=annotated_message.properties.content_type, + content_encoding=annotated_message.properties.content_encoding, + creation_time=creation_time, + absolute_expiry_time=absolute_expiry_time, + group_id=annotated_message.properties.group_id, + group_sequence=annotated_message.properties.group_sequence, + reply_to_group_id=annotated_message.properties.reply_to_group_id, + ) + elif ttl_set: + message_properties = Properties( # type: ignore[call-arg] + creation_time=creation_time_from_ttl if ttl_set else None, + absolute_expiry_time=absolute_expiry_time_from_ttl if ttl_set else None, + ) + annotations = None + if annotated_message.annotations: + # TODO: Investigate how we originally encoded annotations. 
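The TTL handling in `to_outgoing_amqp_message` derives a `creation_time` in epoch milliseconds and clamps the resulting `absolute_expiry_time` at a maximum. A sketch of the arithmetic, using `timestamp()` for the epoch conversion and a placeholder clamp value (the real bound is `MAX_ABSOLUTE_EXPIRY_TIME` from `_common.constants`):

```python
import datetime

MAX_ABSOLUTE_EXPIRY_TIME = 2**63 - 1  # placeholder; the real value comes from _common.constants

def ttl_to_timestamps(ttl_ms: int, now: datetime.datetime):
    """Creation time in epoch milliseconds; absolute expiry clamped to the maximum."""
    creation_time = int(now.timestamp() * 1000)
    absolute_expiry_time = int(min(MAX_ABSOLUTE_EXPIRY_TIME, creation_time + ttl_ms))
    return creation_time, absolute_expiry_time

now = datetime.datetime(2023, 5, 7, tzinfo=datetime.timezone.utc)
created, expires = ttl_to_timestamps(60_000, now)
print(expires - created)  # 60000
```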
+            annotations = dict(annotated_message.annotations)
+            for key in _LONG_ANNOTATIONS:
+                if key in annotated_message.annotations:
+                    annotations[key] = amqp_long_value(annotated_message.annotations[key])
+
+        if annotated_message.application_properties:
+            for app_key, app_val in annotated_message.application_properties.items():
+                # This is done for parity with uamqp, which decodes bytes to str in
+                # application properties; this matches that behavior.
+                if isinstance(app_val, bytes):
+                    annotated_message.application_properties[app_key] = app_val.decode("utf-8")
+
+        message_dict = {
+            "header": message_header,
+            "properties": message_properties,
+            "application_properties": annotated_message.application_properties,
+            "message_annotations": annotations,
+            "delivery_annotations": annotated_message.delivery_annotations,
+            "data": annotated_message._data_body,  # pylint: disable=protected-access
+            "sequence": annotated_message._sequence_body,  # pylint: disable=protected-access
+            "value": annotated_message._value_body,  # pylint: disable=protected-access
+            "footer": annotated_message.footer,
+        }
+
+        return Message(**message_dict)
+
+    @staticmethod
+    def encode_message(message: "ServiceBusMessage") -> bytes:
+        """
+        Encodes the message's underlying pyamqp.Message.
+        :param ServiceBusMessage message: Message.
+        :rtype: bytes
+        """
+        output = bytearray()
+        return encode_payload(output, message._message)  # pylint: disable=protected-access
+
+    @staticmethod
+    def update_message_app_properties(
+        message: "Message",
+        key: str,
+        value: str
+    ) -> "Message":
+        """
+        Adds the given key/value to the application properties of the message.
+        :param pyamqp.Message message: Message.
+        :param str key: Key to set in application properties.
+        :param str value: Value to set for key in application properties.
+        :rtype: pyamqp.Message
+        """
+        if not message.application_properties:
+            message = message._replace(application_properties={})
+        message.application_properties.setdefault(key, value)
+        return message
+
+    @staticmethod
+    def get_batch_message_encoded_size(message: List[bytes]) -> int:
+        """
+        Gets the batch message encoded size given an underlying Message.
+        :param List message: Message to get encoded size of.
+        :rtype: int
+        """
+        return utils.get_message_encoded_size(BatchMessage(*message))
+
+    @staticmethod
+    def get_message_encoded_size(message: "Message") -> int:
+        """
+        Gets the message encoded size given an underlying Message.
+        :param pyamqp.Message message: Message to get encoded size of.
+        :rtype: int
+        """
+        return utils.get_message_encoded_size(message)
+
+    @staticmethod
+    def get_remote_max_message_size(handler: "AMQPClient") -> int:
+        """
+        Returns max peer message size.
+        :param AMQPClient handler: Client to get remote max message size on link from.
+        :rtype: int
+        """
+        return handler._link.remote_max_message_size  # pylint: disable=protected-access
+
+    @staticmethod
+    def get_handler_link_name(handler: "AMQPClient") -> str:
+        """
+        Returns link name.
+        :param AMQPClient handler: Client to get name of link from.
+        :rtype: str
+        """
+        # pylint: disable=protected-access
+        return handler._link.name
+
+    @staticmethod
+    def create_retry_policy(
+        config: "Configuration", *, is_session: bool = False
+    ) -> "_ServiceBusErrorPolicy":
+        """
+        Creates the error retry policy.
+        :param Configuration config: Configuration.
+        :keyword bool is_session: Is session enabled.
+        """
+        # TODO: What's the retry overlap between servicebus and pyamqp?
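The `_ServiceBusErrorPolicy` defined earlier in this file disables retries entirely on session links and layers custom per-condition backoffs on top of the base policy. A toy sketch of that shape (the real `is_retryable` takes an exception object, not a raw condition):

```python
class RetryPolicy:
    """Minimal stand-in for pyamqp's RetryPolicy."""
    def __init__(self, custom_condition_backoff=None, **kwargs):
        self.custom_condition_backoff = custom_condition_backoff or {}

    def is_retryable(self, condition: bytes) -> bool:
        return True  # the real policy consults its no-retry condition list

class SessionAwarePolicy(RetryPolicy):
    def __init__(self, is_session: bool = False, **kwargs):
        self._is_session = is_session
        super().__init__(
            custom_condition_backoff={
                b"com.microsoft:server-busy": 4,
                b"com.microsoft:timeout": 2,
            },
            **kwargs,
        )

    def is_retryable(self, condition: bytes) -> bool:
        # A lost session lock cannot be re-acquired by retrying,
        # so session links never retry at all.
        if self._is_session:
            return False
        return super().is_retryable(condition)

print(SessionAwarePolicy(is_session=True).is_retryable(b"com.microsoft:server-busy"))  # False
print(SessionAwarePolicy().is_retryable(b"com.microsoft:server-busy"))                 # True
```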
+ return _ServiceBusErrorPolicy( + is_session=is_session, + retry_total=config.retry_total, + retry_backoff_factor=config.retry_backoff_factor, + retry_backoff_max=config.retry_backoff_max, + retry_mode=config.retry_mode, + #no_retry_condition=NO_RETRY_ERRORS, + #custom_condition_backoff=CUSTOM_CONDITION_BACKOFF, + ) + + @staticmethod + def create_connection( + host: str, auth: "JWTTokenAuth", network_trace: bool, **kwargs: Any + ) -> "Connection": + """ + Creates and returns the pyamqp Connection object. + :param str host: The hostname used by pyamqp. + :param JWTTokenAuth auth: The auth used by pyamqp. + :param bool network_trace: Debug setting. + """ + return Connection( + endpoint=host, + sasl_credential=auth.sasl, + network_trace=network_trace, + **kwargs + ) + + @staticmethod + def close_connection(connection: "Connection") -> None: + """ + Closes existing connection. + :param Connection connection: uamqp or pyamqp Connection. + """ + connection.close() + + @staticmethod + def create_send_client( + config: "Configuration", **kwargs: Any + ) ->"SendClient": + """ + Creates and returns the pyamqp SendClient. + :keyword ~azure.servicebus._configuration.Configuration config: The configuration. Required. + + :keyword str target: Required. The target. + :keyword JWTTokenAuth auth: Required. + :keyword retry_policy: Required. + :keyword str client_name: Required. + :keyword properties: Required. 
+ """ + + target = kwargs.pop("target") + return SendClient( + config.hostname, + target, + network_trace=config.logging_enable, + keep_alive_interval=config.keep_alive, + custom_endpoint_address=config.custom_endpoint_address, + connection_verify=config.connection_verify, + transport_type=config.transport_type, + http_proxy=config.http_proxy, + **kwargs, + ) + + @staticmethod + def send_messages( + sender: "ServiceBusSender", + message: Union["ServiceBusMessage", "ServiceBusMessageBatch"], + logger: "Logger", + timeout: int, + last_exception: Optional[Exception] + ) -> None: # pylint: disable=unused-argument + """ + Handles sending of service bus messages. + :param ~azure.servicebus._servicebus_sender.ServiceBusSender sender: The sender with handler + to send messages. + :param Message message: Message to send. + :param logger: Logger. + :param int timeout: Timeout time. + :param last_exception: Exception to raise if message timed out. Only used by uamqp transport. + """ + # pylint: disable=protected-access + sender._open() + try: + if isinstance(message._message, list): # BatchMessage + sender._handler.send_message(BatchMessage(*message._message), timeout=timeout) # pylint:disable=protected-access + else: # Message + sender._handler.send_message(message._message, timeout=timeout) # pylint:disable=protected-access + except TimeoutError: + raise OperationTimeoutError(message="Send operation timed out") + except MessageException as e: + raise PyamqpTransport.create_servicebus_exception(logger, e) + + @staticmethod + def add_batch( + sb_message_batch: "ServiceBusMessageBatch", outgoing_sb_message: "ServiceBusMessage" + ) -> None: # pylint: disable=unused-argument + """ + Add ServiceBusMessage to the data body of the BatchMessage. + :param sb_message_batch: ServiceBusMessageBatch to add data to. + :param outgoing_sb_message: Transformed ServiceBusMessage for sending. 
+        :rtype: None
+        """
+        # pylint: disable=protected-access
+        utils.add_batch(
+            sb_message_batch._message, outgoing_sb_message._message
+        )
+
+    @staticmethod
+    def create_source(source: "Source", session_filter: Optional[str]) -> "Source":
+        """
+        Creates and returns the Source.
+
+        :param Source source: Required.
+        :param str or None session_filter: Required.
+        """
+        filter_map = {SESSION_FILTER: session_filter}
+        source = Source(address=source, filters=filter_map)  # type: ignore[call-arg]
+        return source
+
+    @staticmethod
+    def create_receive_client(
+        receiver: "ServiceBusReceiver", **kwargs: "Any"
+    ) -> "ReceiveClient":
+        """
+        Creates and returns the receive client.
+        :param ServiceBusReceiver receiver: The receiver whose configuration is used.
+
+        :keyword str source: Required. The source.
+        :keyword str offset: Required.
+        :keyword str offset_inclusive: Required.
+        :keyword JWTTokenAuth auth: Required.
+        :keyword int idle_timeout: Required.
+        :keyword network_trace: Required.
+        :keyword retry_policy: Required.
+        :keyword str client_name: Required.
+        :keyword dict link_properties: Required.
+        :keyword properties: Required.
+        :keyword link_credit: Required. The prefetch.
+        :keyword keep_alive_interval: Required.
+        :keyword desired_capabilities: Required.
+        :keyword streaming_receive: Required.
+        :keyword timeout: Required.
+ """ + config = receiver._config # pylint: disable=protected-access + source = kwargs.pop("source") + receive_mode = kwargs.pop("receive_mode") + return ReceiveClient( + config.hostname, + source, + http_proxy=config.http_proxy, + transport_type=config.transport_type, + custom_endpoint_address=config.custom_endpoint_address, + connection_verify=config.connection_verify, + receive_settle_mode=PyamqpTransport.ServiceBusToAMQPReceiveModeMap[receive_mode], + send_settle_mode=constants.SenderSettleMode.Settled + if receive_mode == ServiceBusReceiveMode.RECEIVE_AND_DELETE + else constants.SenderSettleMode.Unsettled, + on_attach=functools.partial( + PyamqpTransport.on_attach, + receiver + ), + **kwargs + ) + + # TODO: ask why this is a callable. + @staticmethod + def on_attach( + receiver: "ServiceBusReceiver", + attach_frame: "AttachFrame" + ) -> None: + """ + Receiver on_attach callback. + + :param ServiceBusReceiver receiver: Required. + :param AttachFrame attach_frame: Required. + """ + # pylint: disable=protected-access, unused-argument + if receiver._session and attach_frame.source.address.decode() == receiver._entity_uri: + # This has to live on the session object so that autorenew has access to it. 
+ receiver._session._session_start = utc_now() + expiry_in_seconds = attach_frame.properties.get(SESSION_LOCKED_UNTIL) + if expiry_in_seconds: + expiry_in_seconds = ( + expiry_in_seconds - DATETIMEOFFSET_EPOCH + ) / 10000000 + receiver._session._locked_until_utc = utc_from_timestamp(expiry_in_seconds) + session_filter = attach_frame.source.filters[SESSION_FILTER] + receiver._session_id = session_filter.decode(receiver._config.encoding) + receiver._session._session_id = receiver._session_id + + @staticmethod + def iter_contextual_wrapper( + receiver: "ServiceBusReceiver", max_wait_time: Optional[int] = None + ) -> Iterator["ServiceBusReceivedMessage"]: + """The purpose of this wrapper is to allow both state restoration (for multiple concurrent iteration) + and per-iter argument passing that requires the former.""" + while True: + try: + # pylint: disable=protected-access + message = receiver._inner_next(wait_time=max_wait_time) + links = get_receive_links(message) + with receive_trace_context_manager(receiver, links=links): + yield message + except StopIteration: + break + + @staticmethod + def iter_next( + receiver: "ServiceBusReceiver", wait_time: Optional[int] = None + ) -> "ServiceBusReceivedMessage": + """ + Used to iterate through received messages. 
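The `on_attach` callback above converts the broker's `com.microsoft:locked-until-utc` value from .NET DateTimeOffset ticks (100-nanosecond units) to a UTC datetime by subtracting the tick count at the Unix epoch and dividing by 10,000,000. A sketch, with the epoch constant written out as an assumption mirroring `DATETIMEOFFSET_EPOCH`:

```python
from datetime import datetime, timezone

# .NET DateTimeOffset counts 100-nanosecond ticks since 0001-01-01. This is
# the tick count at the Unix epoch (assumed equivalent of DATETIMEOFFSET_EPOCH).
TICKS_AT_UNIX_EPOCH = 621_355_968_000_000_000
TICKS_PER_SECOND = 10_000_000

def locked_until_from_ticks(ticks: int) -> datetime:
    """Convert a .NET tick count into a timezone-aware UTC datetime."""
    seconds = (ticks - TICKS_AT_UNIX_EPOCH) / TICKS_PER_SECOND
    return datetime.fromtimestamp(seconds, tz=timezone.utc)

# One minute past the Unix epoch:
print(locked_until_from_ticks(TICKS_AT_UNIX_EPOCH + 60 * TICKS_PER_SECOND))
# 1970-01-01 00:01:00+00:00
```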
+ """ + # pylint: disable=protected-access + try: + receiver._receive_context.set() + receiver._open() + if not receiver._message_iter or wait_time: + receiver._message_iter = receiver._handler.receive_messages_iter(timeout=wait_time) + pyamqp_message = next( + cast(Iterator["Message"], receiver._message_iter) + ) + message = receiver._build_received_message(pyamqp_message) + if ( + receiver._auto_lock_renewer + and not receiver._session + and receiver._receive_mode != ServiceBusReceiveMode.RECEIVE_AND_DELETE + ): + receiver._auto_lock_renewer.register(receiver, message) + return message + finally: + receiver._receive_context.clear() + + @staticmethod + def enhanced_message_received( + receiver: "ServiceBusReceiver", + frame: "AttachFrame", + message: "Message" + ) -> None: + """ + Receiver enhanced_message_received callback. + """ + # pylint: disable=protected-access + receiver._handler._last_activity_timestamp = time.time() + if receiver._receive_context.is_set(): + receiver._handler._received_messages.put((frame, message)) + else: + receiver._handler.settle_messages(frame[1], 'released') + + @staticmethod + def build_received_message( + receiver: "ServiceBusReceiver", + message_type: Type["ServiceBusReceivedMessage"], + received: "Message" + ) -> "ServiceBusReceivedMessage": + """ + Build ServiceBusReceivedMessage. + """ + # pylint: disable=protected-access + message = message_type( + message=received[1], + receive_mode=receiver._receive_mode, + receiver=receiver, + frame=received[0], + amqp_transport=receiver._amqp_transport + ) + receiver._last_received_sequenced_number = message.sequence_number + return message + + @staticmethod + def get_current_time( + handler: "ReceiveClient" + ) -> float: # pylint: disable=unused-argument + """ + Gets the current time. + """ + return time.time() + + @staticmethod + def reset_link_credit( + handler: "ReceiveClient", link_credit: int + ) -> None: + """ + Resets the link credit on the link. 
+ :param ReceiveClient handler: Client with link to reset link credit. + :param int link_credit: Link credit needed. + :rtype: None + """ + handler._link.flow(link_credit=link_credit) # pylint: disable=protected-access + + @staticmethod + def settle_message_via_receiver_link( + handler: ReceiveClient, + message: "ServiceBusReceivedMessage", + settle_operation: str, + dead_letter_reason: Optional[str] = None, + dead_letter_error_description: Optional[str] = None, + ) -> None: + # pylint: disable=protected-access + if settle_operation == MESSAGE_COMPLETE: + return handler.settle_messages(message._delivery_id, 'accepted') + if settle_operation == MESSAGE_ABANDON: + return handler.settle_messages( + message._delivery_id, + 'modified', + delivery_failed=True, + undeliverable_here=False + ) + if settle_operation == MESSAGE_DEAD_LETTER: + return handler.settle_messages( + message._delivery_id, + 'rejected', + error=AMQPError( + condition=DEADLETTERNAME, + description=dead_letter_error_description, + info={ + RECEIVER_LINK_DEAD_LETTER_REASON: dead_letter_reason, + RECEIVER_LINK_DEAD_LETTER_ERROR_DESCRIPTION: dead_letter_error_description, + } + ) + ) + if settle_operation == MESSAGE_DEFER: + return handler.settle_messages( + message._delivery_id, + 'modified', + delivery_failed=True, + undeliverable_here=True + ) + raise ValueError( + f"Unsupported settle operation type: {settle_operation}" + ) + + @staticmethod + def parse_received_message( + message: "Message", + message_type: Type["ServiceBusReceivedMessage"], + **kwargs: Any + ) -> List["ServiceBusReceivedMessage"]: + """ + Parses peek/deferred op messages into ServiceBusReceivedMessage. + :param Message message: Message to parse. + :param ServiceBusReceivedMessage message_type: Parse messages to return. + :keyword ServiceBusReceiver receiver: Required. + :keyword bool is_peeked_message: Optional. For peeked messages. + :keyword bool is_deferred_message: Optional. For deferred messages. 
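`settle_message_via_receiver_link` above is a straight dispatch from the Service Bus settlement verb to an AMQP disposition. Tabulated as a sketch (names illustrative; the real calls go through `ReceiveClient.settle_messages`, and dead-lettering additionally attaches an `AMQPError` carrying the reason and description):

```python
# Service Bus settlement verb -> (AMQP disposition outcome, extra disposition flags).
SETTLE_OUTCOMES = {
    "complete": ("accepted", {}),
    "abandon": ("modified", {"delivery_failed": True, "undeliverable_here": False}),
    "defer": ("modified", {"delivery_failed": True, "undeliverable_here": True}),
    "dead_letter": ("rejected", {}),  # plus an AMQPError with reason/description info
}

def settle(operation: str):
    try:
        return SETTLE_OUTCOMES[operation]
    except KeyError:
        raise ValueError(f"Unsupported settle operation type: {operation}") from None

print(settle("abandon")[0])                        # modified
print(settle("defer")[1]["undeliverable_here"])    # True
```

Note the abandon/defer distinction: both use the `modified` outcome, but only defer marks the message `undeliverable_here`, so it can only come back via the deferred-receive path.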
+        :keyword ServiceBusReceiveMode receive_mode: Optional.
+        """
+        parsed = []
+        for m in message.value[b"messages"]:
+            wrapped = decode_payload(memoryview(m[b"message"]))
+            parsed.append(
+                message_type(
+                    wrapped, **kwargs
+                )
+            )
+        return parsed
+
+    @staticmethod
+    def get_message_value(message: "Message") -> Any:
+        """Get body of type value from message."""
+        return message.value
+
+    @staticmethod
+    def create_token_auth(
+        auth_uri: str,
+        get_token: Callable,
+        token_type: bytes,
+        config: "Configuration",
+        **kwargs: Any
+    ) -> "JWTTokenAuth":
+        """
+        Creates the JWTTokenAuth.
+        :param str auth_uri: The auth uri to pass to JWTTokenAuth.
+        :param get_token: The callback function used for getting and refreshing
+         tokens. It should return a valid jwt token each time it is called.
+        :param bytes token_type: Token type.
+        :param Configuration config: Service Bus config.
+
+        :keyword bool update_token: Whether to update token. If not updating token, then pass 300 to refresh_window.
+        """
+        # TODO: figure out why we're passing all these args to pyamqp JWTTokenAuth, which aren't being used
+        update_token = kwargs.pop("update_token")
+        if update_token:
+            # update_token not actually needed by pyamqp
+            # just using to detect which kwargs to pass
+            return JWTTokenAuth(auth_uri, auth_uri, get_token)
+        return JWTTokenAuth(
+            auth_uri,
+            auth_uri,
+            get_token,
+            token_type=token_type,
+            timeout=config.auth_timeout,
+            custom_endpoint_hostname=config.custom_endpoint_hostname,
+            port=config.connection_port,
+            verify=config.connection_verify,
+        )
+
+    @staticmethod
+    def create_mgmt_msg(
+        message: "Message",
+        application_properties: Dict[str, Any],
+        config: "Configuration",
+        reply_to: str,
+        **kwargs: Any
+    ) -> "Message":  # pylint:disable=unused-argument
+        """
+        :param message: The message to send in the management request.
+        :paramtype message: Any
+        :param Dict[bytes, str] application_properties: App props.
+        :param ~azure.servicebus._common._configuration.Configuration config: Configuration.
+        :param str reply_to: Reply to.
+        :rtype: pyamqp.Message
+        """
+        return Message(  # type: ignore  # TODO: fix mypy error
+            value=message,
+            properties=Properties(
+                reply_to=reply_to,
+                **kwargs
+            ),
+            application_properties=application_properties,
+        )
+
+    @staticmethod
+    def mgmt_client_request(
+        mgmt_client: "AMQPClient",
+        mgmt_msg: "Message",
+        *,
+        operation: bytes,
+        operation_type: bytes,
+        node: bytes,
+        timeout: int,
+        callback: Callable
+    ) -> "ServiceBusReceivedMessage":
+        """
+        Send mgmt request.
+        :param AMQPClient mgmt_client: Client to send request with.
+        :param Message mgmt_msg: Message.
+        :keyword bytes operation: Operation.
+        :keyword bytes operation_type: Op type.
+        :keyword bytes node: Mgmt target.
+        :keyword int timeout: Timeout.
+        :keyword Callable callback: Callback to process request response.
+        """
+        status, description, response = mgmt_client.mgmt_request(
+            mgmt_msg,
+            operation=amqp_string_value(operation.decode("UTF-8")),
+            operation_type=amqp_string_value(operation_type),
+            node=node,
+            timeout=timeout,  # TODO: check if this should be seconds * 1000 if timeout else None,
+        )
+        return callback(status, response, description, amqp_transport=PyamqpTransport)
+
+    @staticmethod
+    def _handle_amqp_exception_with_condition(
+        logger: "Logger",
+        condition: Optional["ErrorCondition"],
+        description: str,
+        exception: Optional["AMQPException"] = None,
+        status_code: Optional[str] = None,
+        *,
+        custom_endpoint_address: Optional[str] = None
+    ) -> "ServiceBusError":
+        # handling AMQP Errors that have the condition field or the mgmt handler
+        logger.info(
+            "AMQP error occurred: (%r), condition: (%r), description: (%r).",
+            exception,
+            condition,
+            description,
+        )
+        error_cls: Type["ServiceBusError"]
+        if isinstance(exception, AuthenticationException):
+            logger.info("AMQP Connection authentication error occurred: (%r).", exception)
+            if custom_endpoint_address:
+                # for uamqp parity, invalid custom endpoint address raises ServiceBusConnectionError
+                error_cls = ServiceBusConnectionError
+            else:
+                error_cls = ServiceBusAuthenticationError
+        # elif isinstance(exception, MessageException):
+        #     logger.info("AMQP Message error occurred: (%r).", exception)
+        #     if isinstance(exception, MessageAlreadySettled):
+        #         error_cls = MessageAlreadySettled
+        #     elif isinstance(exception, MessageContentTooLarge):
+        #         error_cls = MessageSizeExceededError
+        elif condition == ErrorCondition.NotFound:
+            # handle NotFound error code
+            error_cls = (
+                ServiceBusCommunicationError
+                if isinstance(exception, AMQPConnectionError)
+                else MessagingEntityNotFoundError
+            )
+        elif condition == ErrorCondition.ClientError and "timed out" in str(exception):
+            # handle send timeout
+            error_cls = OperationTimeoutError
+        elif condition == ErrorCondition.UnknownError or isinstance(exception, AMQPConnectionError):
+            error_cls = ServiceBusConnectionError
+        else:
+            error_cls = _ERROR_CODE_TO_ERROR_MAPPING.get(cast(bytes, condition), ServiceBusError)
+
+        error = error_cls(
+            message=description,
+            error=exception,
+            condition=condition,
+            status_code=status_code,
+        )
+        if condition in _ServiceBusErrorPolicy.no_retry:
+            error._retryable = False  # pylint: disable=protected-access
+        else:
+            error._retryable = True  # pylint: disable=protected-access
+
+        return error
+
+    @staticmethod
+    def handle_amqp_mgmt_error(
+        logger: "Logger",
+        error_description: "str",
+        condition: Optional["ErrorCondition"] = None,
+        description: Optional[str] = None,
+        status_code: Optional[str] = None
+    ) -> "ServiceBusError":
+        if description:
+            error_description += f" {description}."
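The `_handle_amqp_exception_with_condition` logic above ends in a dictionary lookup with a generic fallback. The standalone sketch below (not SDK code; the class names and condition bytes are simplified stand-ins) isolates that fallback pattern, which guarantees every AMQP condition resolves to some `ServiceBusError` subclass:

```python
# Illustrative sketch: condition-bytes -> exception-class lookup with a
# generic fallback, mirroring the _ERROR_CODE_TO_ERROR_MAPPING.get(...) call
# above. These classes and condition values are hypothetical stand-ins.
class ServiceBusError(Exception): ...
class MessageLockLostError(ServiceBusError): ...
class OperationTimeoutError(ServiceBusError): ...

_ERROR_CODE_TO_ERROR_MAPPING = {
    b"com.microsoft:message-lock-lost": MessageLockLostError,
    b"com.microsoft:timeout": OperationTimeoutError,
}

def error_class_for(condition: bytes) -> type:
    # Unrecognized conditions fall back to the generic ServiceBusError,
    # so callers always get a typed Service Bus exception to raise.
    return _ERROR_CODE_TO_ERROR_MAPPING.get(condition, ServiceBusError)
```

The `.get(..., default)` fallback is what lets new or unexpected broker conditions degrade gracefully instead of raising a `KeyError` inside the error handler itself.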
+
+        raise PyamqpTransport._handle_amqp_exception_with_condition(
+            logger,
+            condition,
+            description=error_description,
+            exception=None,
+            status_code=status_code,
+        )
+
+    @staticmethod
+    def create_servicebus_exception(
+        logger: "Logger", exception: Exception, *, custom_endpoint_address: Optional[str] = None
+    ) -> "ServiceBusError":
+        if isinstance(exception, AMQPException):
+            # handling AMQP Errors that have the condition field
+            condition = exception.condition
+            description = exception.description
+            exception = PyamqpTransport._handle_amqp_exception_with_condition(
+                logger,
+                condition,
+                description,
+                exception=exception,
+                custom_endpoint_address=custom_endpoint_address
+            )
+        elif not isinstance(exception, ServiceBusError):
+            logger.exception(
+                "Unexpected error occurred (%r). Handler shutting down.", exception
+            )
+            exception = ServiceBusError(
+                message=f"Handler failed: {exception}.", error=exception
+            )
+
+        return exception
diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_transport/_uamqp_transport.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_transport/_uamqp_transport.py
new file mode 100644
index 0000000000000..eac376013c23a
--- /dev/null
+++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_transport/_uamqp_transport.py
@@ -0,0 +1,1099 @@
+# --------------------------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for license information.
+# --------------------------------------------------------------------------------------------
+
+# pylint: disable=too-many-lines
+import time
+import functools
+import datetime
+from datetime import timezone
+from typing import (
+    Optional,
+    List,
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    Union,
+    Iterator,
+    Type,
+    cast,
+    Iterable,
+)
+
+try:
+    from uamqp import (
+        BatchMessage,
+        constants,
+        MessageBodyType,
+        Message,
+        types,
+        SendClient,
+        ReceiveClient,
+        Source,
+        compat,
+        Connection,
+        __version__,
+    )
+    from uamqp.authentication import JWTTokenAuth
+    from uamqp.constants import ErrorCodes as AMQPErrorCodes
+    from uamqp.message import (
+        MessageHeader,
+        MessageProperties,
+    )
+    from uamqp.errors import (
+        AMQPConnectionError,
+        AMQPError,
+        AuthenticationException,
+        ErrorAction,
+        ErrorPolicy,
+        MessageAlreadySettled,
+        MessageContentTooLarge,
+        MessageException,
+    )
+    from ._base import AmqpTransport
+    from ..amqp._constants import AmqpMessageBodyType
+    from .._common.utils import utc_from_timestamp, utc_now
+    from .._common.tracing import get_receive_links, receive_trace_context_manager
+    from .._common.constants import (
+        UAMQP_LIBRARY,
+        DATETIMEOFFSET_EPOCH,
+        RECEIVER_LINK_DEAD_LETTER_ERROR_DESCRIPTION,
+        RECEIVER_LINK_DEAD_LETTER_REASON,
+        DEADLETTERNAME,
+        MAX_ABSOLUTE_EXPIRY_TIME,
+        MAX_DURATION_VALUE,
+        MESSAGE_COMPLETE,
+        MESSAGE_ABANDON,
+        MESSAGE_DEFER,
+        MESSAGE_DEAD_LETTER,
+        SESSION_FILTER,
+        SESSION_LOCKED_UNTIL,
+        ERROR_CODE_SESSION_LOCK_LOST,
+        ERROR_CODE_MESSAGE_LOCK_LOST,
+        ERROR_CODE_MESSAGE_NOT_FOUND,
+        ERROR_CODE_TIMEOUT,
+        ERROR_CODE_AUTH_FAILED,
+        ERROR_CODE_SESSION_CANNOT_BE_LOCKED,
+        ERROR_CODE_SERVER_BUSY,
+        ERROR_CODE_ARGUMENT_ERROR,
+        ERROR_CODE_OUT_OF_RANGE,
+        ERROR_CODE_ENTITY_DISABLED,
+        ERROR_CODE_ENTITY_ALREADY_EXISTS,
+        ERROR_CODE_PRECONDITION_FAILED,
+        ServiceBusReceiveMode,
+    )
+
+    from ..exceptions import (
+        MessageSizeExceededError,
+        ServiceBusQuotaExceededError,
+        ServiceBusAuthorizationError,
+        ServiceBusError,
+        ServiceBusConnectionError,
+        ServiceBusCommunicationError,
+        MessageAlreadySettled,
+        MessageLockLostError,
+        MessageNotFoundError,
+        MessagingEntityDisabledError,
+        MessagingEntityNotFoundError,
+        MessagingEntityAlreadyExistsError,
+        ServiceBusServerBusyError,
+        ServiceBusAuthenticationError,
+        SessionCannotBeLockedError,
+        SessionLockLostError,
+        OperationTimeoutError,
+    )
+
+    if TYPE_CHECKING:
+        from uamqp import AMQPClient, Target
+        from logging import Logger
+        from ..amqp import (
+            AmqpAnnotatedMessage,
+            AmqpMessageHeader,
+            AmqpMessageProperties,
+        )
+        from .._servicebus_receiver import ServiceBusReceiver
+        from .._servicebus_sender import ServiceBusSender
+        from ..aio._servicebus_sender_async import (
+            ServiceBusSender as ServiceBusSenderAsync,
+        )
+        from .._common.message import (
+            ServiceBusReceivedMessage,
+            ServiceBusMessage,
+            ServiceBusMessageBatch,
+        )
+        from .._common._configuration import Configuration
+
+    _NO_RETRY_CONDITION_ERROR_CODES = (
+        constants.ErrorCodes.DecodeError,
+        constants.ErrorCodes.LinkMessageSizeExceeded,
+        constants.ErrorCodes.NotFound,
+        constants.ErrorCodes.NotImplemented,
+        constants.ErrorCodes.LinkRedirect,
+        constants.ErrorCodes.NotAllowed,
+        constants.ErrorCodes.UnauthorizedAccess,
+        constants.ErrorCodes.LinkStolen,
+        constants.ErrorCodes.ResourceLimitExceeded,
+        constants.ErrorCodes.ConnectionRedirect,
+        constants.ErrorCodes.PreconditionFailed,
+        constants.ErrorCodes.InvalidField,
+        constants.ErrorCodes.ResourceDeleted,
+        constants.ErrorCodes.IllegalState,
+        constants.ErrorCodes.FrameSizeTooSmall,
+        constants.ErrorCodes.ConnectionFramingError,
+        constants.ErrorCodes.SessionUnattachedHandle,
+        constants.ErrorCodes.SessionHandleInUse,
+        constants.ErrorCodes.SessionErrantLink,
+        constants.ErrorCodes.SessionWindowViolation,
+        ERROR_CODE_SESSION_LOCK_LOST,
+        ERROR_CODE_MESSAGE_LOCK_LOST,
+        ERROR_CODE_OUT_OF_RANGE,
+        ERROR_CODE_ARGUMENT_ERROR,
+        ERROR_CODE_PRECONDITION_FAILED,
+    )
+
+    _ERROR_CODE_TO_ERROR_MAPPING = {
+        constants.ErrorCodes.LinkMessageSizeExceeded: MessageSizeExceededError,
+        constants.ErrorCodes.ResourceLimitExceeded: ServiceBusQuotaExceededError,
+        constants.ErrorCodes.UnauthorizedAccess: ServiceBusAuthorizationError,
+        constants.ErrorCodes.NotImplemented: ServiceBusError,
+        constants.ErrorCodes.NotAllowed: ServiceBusError,
+        constants.ErrorCodes.LinkDetachForced: ServiceBusConnectionError,
+        ERROR_CODE_MESSAGE_LOCK_LOST: MessageLockLostError,
+        ERROR_CODE_MESSAGE_NOT_FOUND: MessageNotFoundError,
+        ERROR_CODE_AUTH_FAILED: ServiceBusAuthorizationError,
+        ERROR_CODE_ENTITY_DISABLED: MessagingEntityDisabledError,
+        ERROR_CODE_ENTITY_ALREADY_EXISTS: MessagingEntityAlreadyExistsError,
+        ERROR_CODE_SERVER_BUSY: ServiceBusServerBusyError,
+        ERROR_CODE_SESSION_CANNOT_BE_LOCKED: SessionCannotBeLockedError,
+        ERROR_CODE_SESSION_LOCK_LOST: SessionLockLostError,
+        ERROR_CODE_ARGUMENT_ERROR: ServiceBusError,
+        ERROR_CODE_OUT_OF_RANGE: ServiceBusError,
+        ERROR_CODE_TIMEOUT: OperationTimeoutError,
+    }
+
+    def _error_handler(error):
+        """Handle connection and service errors.
+
+        Called internally when a message has failed to send so we
+        can parse the error to determine whether we should attempt
+        to retry sending the message again.
+        Returns the action to take according to error type.
+
+        :param error: The error received in the send attempt.
+        :type error: Exception
+        :rtype: ~uamqp.errors.ErrorAction
+        """
+        if error.condition == b"com.microsoft:server-busy":
+            return ErrorAction(retry=True, backoff=4)
+        if error.condition == b"com.microsoft:timeout":
+            return ErrorAction(retry=True, backoff=2)
+        if error.condition == b"com.microsoft:operation-cancelled":
+            return ErrorAction(retry=True)
+        if error.condition == b"com.microsoft:container-close":
+            return ErrorAction(retry=True, backoff=4)
+        if error.condition in _NO_RETRY_CONDITION_ERROR_CODES:
+            return ErrorAction(retry=False)
+        return ErrorAction(retry=True)
+
+    class _ServiceBusErrorPolicy(ErrorPolicy):
+        def __init__(self, max_retries=3, is_session=False):
+            self._is_session = is_session
+            super(_ServiceBusErrorPolicy, self).__init__(
+                max_retries=max_retries, on_error=_error_handler
+            )
+
+        def on_unrecognized_error(self, error):
+            if self._is_session:
+                return ErrorAction(retry=False)
+            return super(_ServiceBusErrorPolicy, self).on_unrecognized_error(error)
+
+        def on_link_error(self, error):
+            if self._is_session:
+                return ErrorAction(retry=False)
+            return super(_ServiceBusErrorPolicy, self).on_link_error(error)
+
+        def on_connection_error(self, error):
+            if self._is_session:
+                return ErrorAction(retry=False)
+            return super(_ServiceBusErrorPolicy, self).on_connection_error(error)
+
+    class UamqpTransport(AmqpTransport):  # pylint: disable=too-many-public-methods
+        """
+        Class which defines uamqp-based methods used by the sender and receiver.
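The `_error_handler` above encodes the retry policy as a chain of condition checks: transient broker conditions retry with a backoff, known-fatal AMQP conditions do not retry, and everything else retries without backoff. The sketch below (not SDK code; `ErrorAction` here is a simplified dataclass stand-in for `uamqp.errors.ErrorAction`, and the no-retry set is abbreviated) expresses that decision as a plain function:

```python
# Illustrative sketch of the retry decision made by _error_handler above.
from dataclasses import dataclass

@dataclass
class ErrorAction:
    retry: bool
    backoff: int = 0  # seconds to wait before retrying

# Abbreviated stand-in for _NO_RETRY_CONDITION_ERROR_CODES.
_NO_RETRY = {b"amqp:decode-error", b"amqp:unauthorized-access"}

def decide(condition: bytes) -> ErrorAction:
    # Transient broker conditions retry with a backoff.
    if condition == b"com.microsoft:server-busy":
        return ErrorAction(retry=True, backoff=4)
    if condition == b"com.microsoft:timeout":
        return ErrorAction(retry=True, backoff=2)
    # Known-fatal conditions never retry.
    if condition in _NO_RETRY:
        return ErrorAction(retry=False)
    # Default: retry immediately.
    return ErrorAction(retry=True)
```

The defensive default (retry unless the condition is known to be fatal) is what makes the policy safe against broker conditions the client has never seen.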
+        """
+
+        KIND = "uamqp"
+
+        # define constants
+        MAX_FRAME_SIZE_BYTES = constants.MAX_FRAME_SIZE_BYTES
+        MAX_MESSAGE_LENGTH_BYTES = constants.MAX_MESSAGE_LENGTH_BYTES
+        TIMEOUT_FACTOR = 1000
+        # CONNECTION_CLOSING_STATES: Tuple = (  # pylint:disable=protected-access
+        #     c_uamqp.ConnectionState.CLOSE_RCVD,  # pylint:disable=c-extension-no-member
+        #     c_uamqp.ConnectionState.CLOSE_SENT,  # pylint:disable=c-extension-no-member
+        #     c_uamqp.ConnectionState.DISCARDING,  # pylint:disable=c-extension-no-member
+        #     c_uamqp.ConnectionState.END,  # pylint:disable=c-extension-no-member
+        # )
+        TRANSPORT_IDENTIFIER = f"{UAMQP_LIBRARY}/{__version__}"
+
+        # To enable extensible string enums for the public facing parameter
+        # and translate to the "real" uamqp constants.
+        ServiceBusToAMQPReceiveModeMap = {
+            ServiceBusReceiveMode.PEEK_LOCK: constants.ReceiverSettleMode.PeekLock,
+            ServiceBusReceiveMode.RECEIVE_AND_DELETE: constants.ReceiverSettleMode.ReceiveAndDelete,
+        }
+
+        # define symbols
+        PRODUCT_SYMBOL = types.AMQPSymbol("product")
+        VERSION_SYMBOL = types.AMQPSymbol("version")
+        FRAMEWORK_SYMBOL = types.AMQPSymbol("framework")
+        PLATFORM_SYMBOL = types.AMQPSymbol("platform")
+        USER_AGENT_SYMBOL = types.AMQPSymbol("user-agent")
+
+        # amqp value types
+        AMQP_LONG_VALUE: Callable = types.AMQPLong
+        AMQP_ARRAY_VALUE: Callable = types.AMQPArray
+        AMQP_UINT_VALUE: Callable = types.AMQPuInt
+
+        # errors
+        TIMEOUT_ERROR = compat.TimeoutException
+
+        @staticmethod
+        def build_message(**kwargs: Any) -> "Message":
+            """
+            Creates a uamqp.Message with given arguments.
+            :rtype: uamqp.Message
+            """
+            return Message(**kwargs)
+
+        @staticmethod
+        def build_batch_message(data: List) -> "BatchMessage":
+            """
+            Creates a uamqp.BatchMessage with given arguments.
+            :rtype: uamqp.BatchMessage
+            """
+            return BatchMessage(data=data)
+
+        @staticmethod
+        def get_message_delivery_tag(
+            message: "Message", _
+        ) -> str:  # pylint: disable=unused-argument
+            """
+            Gets delivery tag of a Message.
+            :param message: Message to get delivery_tag from for uamqp.Message.
+            :param frame: Message to get delivery_tag from for pyamqp.Message.
+            :rtype: str
+            """
+            return message.delivery_tag
+
+        @staticmethod
+        def get_message_delivery_id(
+            message: "Message", _
+        ) -> str:  # pylint: disable=unused-argument
+            """
+            Gets delivery id of a Message.
+            :param message: Message to get delivery_id from for uamqp.Message.
+            :param frame: Message to get delivery_id from for pyamqp.Message.
+            :rtype: str
+            """
+            return message.delivery_no
+
+        @staticmethod
+        def to_outgoing_amqp_message(
+            annotated_message: "AmqpAnnotatedMessage",
+        ) -> "Message":
+            """
+            Converts an AmqpAnnotatedMessage into an Amqp Message.
+            :param AmqpAnnotatedMessage annotated_message: AmqpAnnotatedMessage to convert.
+            :rtype: uamqp.Message
+            """
+            message_header = None
+            ttl_set = False
+            header_vals = (
+                annotated_message.header.values() if annotated_message.header else None
+            )
+            # If header and non-None header values, create outgoing header.
+            if header_vals and header_vals.count(None) != len(header_vals):
+                annotated_message.header = cast(
+                    "AmqpMessageHeader", annotated_message.header
+                )
+                message_header = MessageHeader()
+                message_header.delivery_count = annotated_message.header.delivery_count
+                message_header.time_to_live = annotated_message.header.time_to_live
+                message_header.first_acquirer = annotated_message.header.first_acquirer
+                message_header.durable = annotated_message.header.durable
+                message_header.priority = annotated_message.header.priority
+                if (
+                    annotated_message.header.time_to_live
+                    and annotated_message.header.time_to_live != MAX_DURATION_VALUE
+                ):
+                    ttl_set = True
+                    creation_time_from_ttl = int(
+                        time.mktime(datetime.datetime.now(timezone.utc).timetuple())
+                        * UamqpTransport.TIMEOUT_FACTOR
+                    )
+                    absolute_expiry_time_from_ttl = int(
+                        min(
+                            MAX_ABSOLUTE_EXPIRY_TIME,
+                            creation_time_from_ttl
+                            + annotated_message.header.time_to_live,
+                        )
+                    )
+
+            message_properties = None
+            properties_vals = (
+                annotated_message.properties.values()
+                if annotated_message.properties
+                else None
+            )
+            # If properties and non-None properties values, create outgoing properties.
+            if properties_vals and properties_vals.count(None) != len(properties_vals):
+                annotated_message.properties = cast(
+                    "AmqpMessageProperties", annotated_message.properties
+                )
+                creation_time = None
+                absolute_expiry_time = None
+                if ttl_set:
+                    creation_time = creation_time_from_ttl
+                    absolute_expiry_time = absolute_expiry_time_from_ttl
+                else:
+                    if annotated_message.properties.creation_time:
+                        creation_time = int(annotated_message.properties.creation_time)
+                    if annotated_message.properties.absolute_expiry_time:
+                        absolute_expiry_time = int(
+                            annotated_message.properties.absolute_expiry_time
+                        )
+
+                message_properties = MessageProperties(
+                    message_id=annotated_message.properties.message_id,
+                    user_id=annotated_message.properties.user_id,
+                    to=annotated_message.properties.to,
+                    subject=annotated_message.properties.subject,
+                    reply_to=annotated_message.properties.reply_to,
+                    correlation_id=annotated_message.properties.correlation_id,
+                    content_type=annotated_message.properties.content_type,
+                    content_encoding=annotated_message.properties.content_encoding,
+                    creation_time=creation_time,
+                    absolute_expiry_time=absolute_expiry_time,
+                    group_id=annotated_message.properties.group_id,
+                    group_sequence=annotated_message.properties.group_sequence,
+                    reply_to_group_id=annotated_message.properties.reply_to_group_id,
+                    encoding=annotated_message._encoding,  # pylint: disable=protected-access
+                )
+            elif ttl_set:
+                message_properties = MessageProperties(
+                    creation_time=creation_time_from_ttl if ttl_set else None,
+                    absolute_expiry_time=absolute_expiry_time_from_ttl
+                    if ttl_set
+                    else None,
+                )
+
+            # pylint: disable=protected-access
+            amqp_body_type = annotated_message.body_type
+            if amqp_body_type == AmqpMessageBodyType.DATA:
+                amqp_body_type = MessageBodyType.Data
+                amqp_body = list(cast(Iterable, annotated_message._data_body))
+            elif amqp_body_type == AmqpMessageBodyType.SEQUENCE:
+                amqp_body_type = MessageBodyType.Sequence
+                amqp_body = list(cast(Iterable, annotated_message._sequence_body))
+            else:
+                amqp_body_type = MessageBodyType.Value
+                amqp_body = annotated_message._value_body
+
+            return Message(
+                body=amqp_body,
+                body_type=amqp_body_type,
+                header=message_header,
+                properties=message_properties,
+                application_properties=annotated_message.application_properties,
+                annotations=annotated_message.annotations,
+                delivery_annotations=annotated_message.delivery_annotations,
+                footer=annotated_message.footer,
+            )
+
+        @staticmethod
+        def encode_message(message: "ServiceBusMessage") -> bytes:
+            """
+            Encodes the outgoing uamqp.Message of the message.
+            :param ServiceBusMessage message: Message.
+            :rtype: bytes
+            """
+            return cast("Message", message._message).encode_message()
+
+        @staticmethod
+        def update_message_app_properties(
+            message: "Message", key: str, value: str
+        ) -> "Message":
+            """
+            Adds the given key/value to the application properties of the message.
+            :param uamqp.Message message: Message.
+            :param str key: Key to set in application properties.
+            :param str value: Value to set for key in application properties.
+            :rtype: uamqp.Message
+            """
+            if not message.application_properties:
+                message.application_properties = {}
+            message.application_properties.setdefault(key, value)
+            return message
+
+        @staticmethod
+        def get_batch_message_encoded_size(message: "BatchMessage") -> int:
+            """
+            Gets the batch message encoded size given an underlying Message.
+            :param uamqp.BatchMessage message: Message to get encoded size of.
+            :rtype: int
+            """
+            return message.gather()[0].get_message_encoded_size()
+
+        @staticmethod
+        def get_message_encoded_size(message: "Message") -> int:
+            """
+            Gets the message encoded size given an underlying Message.
+            :param uamqp.Message message: Message to get encoded size of.
+            :rtype: int
+            """
+            return message.get_message_encoded_size()
+
+        @staticmethod
+        def get_remote_max_message_size(handler: "AMQPClient") -> int:
+            """
+            Returns max peer message size.
+            :param AMQPClient handler: Client to get remote max message size on link from.
+            :rtype: int
+            """
+            return (
+                handler.message_handler._link.peer_max_message_size
+            )  # pylint:disable=protected-access
+
+        @staticmethod
+        def get_handler_link_name(handler: "AMQPClient") -> str:
+            """
+            Returns link name.
+            :param AMQPClient handler: Client to get name of link from.
+            :rtype: str
+            """
+            return handler.message_handler.name
+
+        @staticmethod
+        def create_retry_policy(
+            config: "Configuration", *, is_session: bool = False
+        ) -> "_ServiceBusErrorPolicy":
+            """
+            Creates the error retry policy.
+            :param ~azure.servicebus._common._configuration.Configuration config:
+             Configuration.
+            :keyword bool is_session: Is session enabled.
+            """
+            # TODO: What's the retry overlap between servicebus and pyamqp?
+            return _ServiceBusErrorPolicy(
+                max_retries=config.retry_total, is_session=is_session
+            )
+
+        @staticmethod
+        def create_connection(
+            host: str, auth: "JWTTokenAuth", network_trace: bool, **kwargs: Any
+        ) -> "Connection":
+            """
+            Creates and returns the uamqp Connection object.
+            :param str host: The hostname, used by uamqp.
+            :param JWTTokenAuth auth: The auth, used by uamqp.
+            :param bool network_trace: Required.
+            """
+            custom_endpoint_address = kwargs.pop(  # pylint:disable=unused-variable
+                "custom_endpoint_address"
+            )
+            ssl_opts = kwargs.pop("ssl_opts")  # pylint:disable=unused-variable
+            transport_type = kwargs.pop(  # pylint:disable=unused-variable
+                "transport_type"
+            )
+            http_proxy = kwargs.pop("http_proxy")  # pylint:disable=unused-variable
+            return Connection(
+                hostname=host,
+                sasl=auth,
+                debug=network_trace,
+            )
+
+        @staticmethod
+        def close_connection(connection: "Connection") -> None:
+            """
+            Closes existing connection.
+            :param connection: uamqp or pyamqp Connection.
+            """
+            connection.destroy()
+
+        @staticmethod
+        def create_send_client(config: "Configuration", **kwargs: Any) -> "SendClient":
+            """
+            Creates and returns the uamqp SendClient.
+            :param ~azure.servicebus._common._configuration.Configuration config:
+             The configuration.
+            :keyword str target: Required. The target.
+            :keyword JWTTokenAuth auth: Required.
+            :keyword int idle_timeout: Required.
+            :keyword network_trace: Required.
+            :keyword retry_policy: Required.
+            :keyword keep_alive_interval: Required.
+            :keyword str client_name: Required.
+            :keyword dict link_properties: Required.
+            :keyword properties: Required.
+            """
+            target = kwargs.pop("target")
+            retry_policy = kwargs.pop("retry_policy")
+
+            return SendClient(
+                target,
+                debug=config.logging_enable,
+                error_policy=retry_policy,
+                keep_alive_interval=config.keep_alive,
+                encoding=config.encoding,
+                **kwargs,
+            )
+
+        @staticmethod
+        def set_msg_timeout(
+            sender: Union["ServiceBusSender", "ServiceBusSenderAsync"],
+            logger: "Logger",
+            timeout: int,
+            last_exception: Optional[Exception],
+        ) -> None:
+            # pylint: disable=protected-access
+            if not timeout:
+                cast("SendClient", sender._handler)._msg_timeout = 0
+                return
+            if timeout <= 0.0:
+                if last_exception:
+                    error = last_exception
+                else:
+                    error = OperationTimeoutError(message="Send operation timed out")
+                logger.info("%r send operation timed out. (%r)", sender._name, error)
+                raise error
+            cast("SendClient", sender._handler)._msg_timeout = (
+                timeout * UamqpTransport.TIMEOUT_FACTOR  # type: ignore
+            )
+
+        @staticmethod
+        def send_messages(
+            sender: "ServiceBusSender",
+            message: Union["ServiceBusMessage", "ServiceBusMessageBatch"],
+            logger: "Logger",
+            timeout: int,
+            last_exception: Optional[Exception],
+        ) -> None:  # pylint: disable=unused-argument
+            """
+            Handles sending of service bus messages.
+            :param ~azure.servicebus.ServiceBusSender sender: The sender with handler
+             to send messages.
+            :param message: ServiceBusMessage with uamqp.Message to be sent.
+            :paramtype message: ~azure.servicebus.ServiceBusMessage or ~azure.servicebus.ServiceBusMessageBatch
+            :param int timeout: Timeout time.
+            :param last_exception: Exception to raise if message timed out. Only used by uamqp transport.
+            :param logger: Logger.
+            """
+            # pylint: disable=protected-access
+            sender._open()
+            default_timeout = cast("SendClient", sender._handler)._msg_timeout
+            try:
+                UamqpTransport.set_msg_timeout(sender, logger, timeout, last_exception)
+                sender._handler.send_message(message._message)
+            finally:
+                UamqpTransport.set_msg_timeout(sender, logger, default_timeout, None)
+
+        @staticmethod
+        def add_batch(
+            sb_message_batch: "ServiceBusMessageBatch",
+            outgoing_sb_message: "ServiceBusMessage",
+        ) -> None:
+            """
+            Add ServiceBusMessage to the data body of the BatchMessage.
+            :param sb_message_batch: ServiceBusMessageBatch to add data to.
+            :param outgoing_sb_message: Transformed ServiceBusMessage for sending.
+            :rtype: None
+            """
+            # pylint: disable=protected-access
+            sb_message_batch._message._body_gen.append(outgoing_sb_message._message)
+
+        @staticmethod
+        def create_source(source: "Source", session_filter: Optional[str]) -> "Source":
+            """
+            Creates and returns the Source.
+
+            :param Source source: Required.
+            :param str or None session_filter: Required.
+            """
+            source = Source(source)
+            source.set_filter(session_filter, name=SESSION_FILTER, descriptor=None)
+            return source
+
+        @staticmethod
+        def create_receive_client(
+            receiver: "ServiceBusReceiver", **kwargs: Any
+        ) -> "ReceiveClient":
+            """
+            Creates and returns the receive client.
+            :param ~azure.servicebus.ServiceBusReceiver receiver: The receiver.
+
+            :keyword str source: Required. The source.
+            :keyword str offset: Required.
+            :keyword str offset_inclusive: Required.
+            :keyword JWTTokenAuth auth: Required.
+            :keyword int idle_timeout: Required.
+            :keyword network_trace: Required.
+            :keyword retry_policy: Required.
+            :keyword str client_name: Required.
+            :keyword dict link_properties: Required.
+            :keyword properties: Required.
+            :keyword link_credit: Required. The prefetch.
+            :keyword keep_alive_interval: Required.
+            :keyword desired_capabilities: Required.
+            :keyword timeout: Required.
+            """
+            source = kwargs.pop("source")
+            retry_policy = kwargs.pop("retry_policy")
+            network_trace = kwargs.pop("network_trace")
+            link_credit = kwargs.pop("link_credit")
+            receive_mode = kwargs.pop("receive_mode")
+
+            return ReceiveClient(
+                source,
+                debug=network_trace,
+                error_policy=retry_policy,
+                prefetch=link_credit,
+                auto_complete=False,
+                receive_settle_mode=UamqpTransport.ServiceBusToAMQPReceiveModeMap[
+                    receive_mode
+                ],
+                send_settle_mode=constants.SenderSettleMode.Settled
+                if receive_mode == ServiceBusReceiveMode.RECEIVE_AND_DELETE
+                else None,
+                on_attach=functools.partial(UamqpTransport.on_attach, receiver),
+                **kwargs,
+            )
+
+        @staticmethod
+        def on_attach(  # pylint: disable=unused-argument
+            receiver: "ServiceBusReceiver",
+            source: "Source",
+            target: "Target",
+            properties: Dict[str, Any],
+            error: Exception,
+        ) -> None:
+            """
+            Receiver on_attach callback.
+            """
+            # pylint: disable=protected-access
+            if receiver._session and str(source) == receiver._entity_uri:
+                # This has to live on the session object so that autorenew has access to it.
+                receiver._session._session_start = utc_now()
+                expiry_in_seconds = properties.get(SESSION_LOCKED_UNTIL)
+                if expiry_in_seconds:
+                    expiry_in_seconds = (
+                        expiry_in_seconds - DATETIMEOFFSET_EPOCH
+                    ) / 10000000
+                    receiver._session._locked_until_utc = utc_from_timestamp(
+                        expiry_in_seconds
+                    )
+                session_filter = source.get_filter(name=SESSION_FILTER)
+                receiver._session_id = session_filter.decode(receiver._config.encoding)
+                receiver._session._session_id = receiver._session_id
+
+        @staticmethod
+        def iter_contextual_wrapper(
+            receiver: "ServiceBusReceiver", max_wait_time: Optional[int] = None
+        ) -> Iterator["ServiceBusReceivedMessage"]:
+            """The purpose of this wrapper is to allow both state restoration (for multiple concurrent iteration)
+            and per-iter argument passing that requires the former."""
+            # pylint: disable=protected-access
+            original_timeout = None
+            while True:
+                # This is not threadsafe, but gives us a way to handle if someone passes
+                # different max_wait_times to different iterators and uses them in concert.
+                if max_wait_time:
+                    original_timeout = receiver._handler._timeout
+                    receiver._handler._timeout = (
+                        max_wait_time * UamqpTransport.TIMEOUT_FACTOR
+                    )
+                try:
+                    message = receiver._inner_next()
+                    links = get_receive_links(message)
+                    with receive_trace_context_manager(receiver, links=links):
+                        yield message
+                except StopIteration:
+                    break
+                finally:
+                    if original_timeout:
+                        try:
+                            receiver._handler._timeout = original_timeout
+                        except AttributeError:  # Handler may be disposed already.
+                            pass
+
+        # wait_time used by pyamqp
+        @staticmethod
+        def iter_next(
+            receiver: "ServiceBusReceiver", wait_time: Optional[int] = None
+        ) -> "ServiceBusReceivedMessage":  # pylint: disable=unused-argument
+            # pylint: disable=protected-access
+            try:
+                receiver._receive_context.set()
+                receiver._open()
+                if not receiver._message_iter:
+                    receiver._message_iter = receiver._handler.receive_messages_iter()
+                uamqp_message = next(
+                    cast(Iterator["Message"], receiver._message_iter)
+                )
+                message = receiver._build_received_message(uamqp_message)
+                if (
+                    receiver._auto_lock_renewer
+                    and not receiver._session
+                    and receiver._receive_mode
+                    != ServiceBusReceiveMode.RECEIVE_AND_DELETE
+                ):
+                    receiver._auto_lock_renewer.register(receiver, message)
+                return message
+            finally:
+                receiver._receive_context.clear()
+
+        @staticmethod
+        def enhanced_message_received(
+            receiver: "ServiceBusReceiver", message: "Message"
+        ) -> None:
+            """
+            Receiver enhanced_message_received callback.
+            """
+            # pylint: disable=protected-access
+            cast("ReceiveClient", receiver._handler)._was_message_received = True
+            if receiver._receive_context.is_set():
+                receiver._handler._received_messages.put(message)
+            else:
+                message.release()
+
+        @staticmethod
+        def build_received_message(
+            receiver: "ServiceBusReceiver",
+            message_type: Type["ServiceBusReceivedMessage"],
+            received: "Message",
+        ) -> "ServiceBusReceivedMessage":
+            # pylint: disable=protected-access
+            message = message_type(
+                message=received,
+                receive_mode=receiver._receive_mode,
+                receiver=receiver,
+                amqp_transport=receiver._amqp_transport
+            )
+            message._uamqp_message = received
+            receiver._last_received_sequenced_number = message.sequence_number
+            return message
+
+        @staticmethod
+        def get_current_time(handler: "ReceiveClient") -> float:
+            """
+            Gets the current time.
+            :param ReceiveClient handler: Handler with counter to get time.
+        :rtype: float
+        """
+        return handler._counter.get_current_ms()  # pylint: disable=protected-access
+
+    @staticmethod
+    def reset_link_credit(handler: "ReceiveClient", link_credit: int) -> None:
+        """
+        Resets the link credit on the link.
+        :param ReceiveClient handler: Client with link to reset link credit.
+        :param int link_credit: Link credit needed.
+        :rtype: None
+        """
+        handler.message_handler.reset_link_credit(link_credit)
+
+    # Executes message settlement; the implementation is in settle_message_via_receiver_link_impl.
+    # May be able to remove this and call the private method directly.
+    @staticmethod
+    def settle_message_via_receiver_link(
+        handler: "ReceiveClient",
+        message: "ServiceBusReceivedMessage",
+        settle_operation: str,
+        dead_letter_reason: Optional[str] = None,
+        dead_letter_error_description: Optional[str] = None,
+    ) -> None:  # pylint: disable=unused-argument
+        UamqpTransport.settle_message_via_receiver_link_impl(
+            handler,
+            message,
+            settle_operation,
+            dead_letter_reason,
+            dead_letter_error_description,
+        )()
+
+    @staticmethod
+    def settle_message_via_receiver_link_impl(
+        _: ReceiveClient,
+        message: "ServiceBusReceivedMessage",
+        settle_operation: str,
+        dead_letter_reason: Optional[str] = None,
+        dead_letter_error_description: Optional[str] = None,
+    ) -> Callable:  # pylint: disable=unused-argument
+        # pylint: disable=protected-access
+        message._message = cast(Message, message._message)
+        if settle_operation == MESSAGE_COMPLETE:
+            return functools.partial(message._message.accept)
+        if settle_operation == MESSAGE_ABANDON:
+            return functools.partial(message._message.modify, True, False)
+        if settle_operation == MESSAGE_DEAD_LETTER:
+            return functools.partial(
+                message._message.reject,
+                condition=DEADLETTERNAME,
+                description=dead_letter_error_description,
+                info={
+                    RECEIVER_LINK_DEAD_LETTER_REASON: dead_letter_reason,
+                    RECEIVER_LINK_DEAD_LETTER_ERROR_DESCRIPTION: dead_letter_error_description,
+                },
+            )
+        if settle_operation == MESSAGE_DEFER:
+            return functools.partial(message._message.modify, True, True)
+        raise ValueError(f"Unsupported settle operation type: {settle_operation}")
+
+    @staticmethod
+    def parse_received_message(
+        message: "Message",
+        message_type: Type["ServiceBusReceivedMessage"],
+        **kwargs: Any,
+    ) -> List["ServiceBusReceivedMessage"]:
+        """
+        Parses peek/deferred op messages into ServiceBusReceivedMessage.
+        :param Message message: Message to parse.
+        :param ServiceBusReceivedMessage message_type: Message type to pass parsed message to.
+        :keyword ServiceBusReceiver receiver: Required.
+        :keyword bool is_peeked_message: Optional. For peeked messages.
+        :keyword bool is_deferred_message: Optional. For deferred messages.
+        :keyword ServiceBusReceiveMode receive_mode: Optional.
+        """
+        parsed = []
+        for m in message.get_data()[b"messages"]:
+            wrapped = Message.decode_from_bytes(bytearray(m[b"message"]))
+            parsed.append(message_type(wrapped, **kwargs))
+        return parsed
+
+    @staticmethod
+    def get_message_value(message: "Message") -> Any:
+        return message.get_data()
+
+    @staticmethod
+    def create_token_auth(
+        auth_uri: str,
+        get_token: Callable,
+        token_type: bytes,
+        config: "Configuration",
+        **kwargs: Any,
+    ) -> "JWTTokenAuth":
+        """
+        Creates the JWTTokenAuth.
+        :param str auth_uri: The auth uri to pass to JWTTokenAuth.
+        :param get_token: The callback function used for getting and refreshing
+         tokens. It should return a valid jwt token each time it is called.
+        :param bytes token_type: Token type.
+        :param ~azure.servicebus._configuration.Configuration config: Service Bus config.
+
+        :keyword bool update_token: Required. Whether to update token. If not updating token,
+         then pass 300 to refresh_window.
+ """ + update_token = kwargs.pop("update_token") + refresh_window = 0 if update_token else 300 + + token_auth = JWTTokenAuth( + auth_uri, + auth_uri, + get_token, + token_type=token_type, + timeout=config.auth_timeout, + http_proxy=config.http_proxy, + transport_type=config.transport_type, + custom_endpoint_hostname=config.custom_endpoint_hostname, + port=config.connection_port, + verify=config.connection_verify, + refresh_window=refresh_window, + ) + if update_token: + token_auth.update_token() + return token_auth + + @staticmethod + def create_mgmt_msg( + message: "Message", + application_properties: Dict[str, Any], + config: "Configuration", # pylint:disable=unused-argument + reply_to: str, + **kwargs: Any, + ) -> "Message": + """ + :param message: The message to send in the management request. + :paramtype message: Any + :param Dict[bytes, str] application_properties: App props. + :param ~azure.servicebus._common._configuration.Configuration config: Configuration. + :param str reply_to: Reply to. + :rtype: uamqp.Message + """ + return Message( # type: ignore # TODO: fix mypy error + body=message, + properties=MessageProperties( + reply_to=reply_to, encoding=config.encoding, **kwargs + ), + application_properties=application_properties, + ) + + @staticmethod + def mgmt_client_request( + mgmt_client: "AMQPClient", + mgmt_msg: "Message", + *, + operation: bytes, + operation_type: bytes, + node: bytes, + timeout: int, + callback: Callable, + ) -> "ServiceBusReceivedMessage": + """ + Send mgmt request and return result of callback. + :param AMQPClient mgmt_client: Client to send request with. + :param Message mgmt_msg: Message. + :keyword bytes operation: Operation. + :keyword bytes operation_type: Op type. + :keyword bytes node: Mgmt target. + :keyword int timeout: Timeout. + :keyword Callable callback: Callback to process request response. 
+ """ + return mgmt_client.mgmt_request( + mgmt_msg, + operation, + op_type=operation_type, + node=node, + timeout=timeout * UamqpTransport.TIMEOUT_FACTOR if timeout else None, + callback=functools.partial(callback, amqp_transport=UamqpTransport), + ) + + @staticmethod + def _handle_amqp_exception_with_condition( + logger: "Logger", + condition: Optional["AMQPErrorCodes"], + description: str, + exception: Optional["AMQPError"] = None, + status_code: Optional[str] = None, + ) -> "ServiceBusError": + # handling AMQP Errors that have the condition field or the mgmt handler + logger.info( + "AMQP error occurred: (%r), condition: (%r), description: (%r).", + exception, + condition, + description, + ) + error_cls: Type["ServiceBusError"] + if condition == AMQPErrorCodes.NotFound: + # handle NotFound error code + error_cls = ( + ServiceBusCommunicationError + if isinstance(exception, AMQPConnectionError) + else MessagingEntityNotFoundError + ) + elif condition == AMQPErrorCodes.ClientError and "timed out" in str( + exception + ): + # handle send timeout + error_cls = OperationTimeoutError + elif condition == AMQPErrorCodes.UnknownError and isinstance( + exception, AMQPConnectionError + ): + error_cls = ServiceBusConnectionError + else: + # handle other error codes + error_cls = _ERROR_CODE_TO_ERROR_MAPPING.get(condition, ServiceBusError) + + error = error_cls( + message=description, + error=exception, + condition=condition, + status_code=status_code, + ) + if condition in _NO_RETRY_CONDITION_ERROR_CODES: + error._retryable = False # pylint: disable=protected-access + else: + error._retryable = True # pylint: disable=protected-access + + return error + + @staticmethod + def _handle_amqp_exception_without_condition( + logger: "Logger", exception: "AMQPError" + ) -> "ServiceBusError": + error_cls: Type[ServiceBusError] = ServiceBusError + if isinstance(exception, AMQPConnectionError): + logger.info("AMQP Connection error occurred: (%r).", exception) + error_cls = 
ServiceBusConnectionError + elif isinstance(exception, AuthenticationException): + logger.info( + "AMQP Connection authentication error occurred: (%r).", exception + ) + error_cls = ServiceBusAuthenticationError + elif isinstance(exception, MessageException): + logger.info("AMQP Message error occurred: (%r).", exception) + if isinstance(exception, MessageAlreadySettled): + error_cls = MessageAlreadySettled + elif isinstance(exception, MessageContentTooLarge): + error_cls = MessageSizeExceededError + else: + logger.info( + "Unexpected AMQP error occurred (%r). Handler shutting down.", + exception, + ) + + error = error_cls(message=str(exception), error=exception) + return error + + @staticmethod + def handle_amqp_mgmt_error( + logger: "Logger", + error_description: "str", + condition: Optional["AMQPErrorCodes"] = None, + description: Optional[str] = None, + status_code: Optional[str] = None, + ) -> "ServiceBusError": + if description: + error_description += f" {description}." + + raise UamqpTransport._handle_amqp_exception_with_condition( + logger, + condition, + description=error_description, + exception=None, + status_code=status_code, + ) + + @staticmethod + def create_servicebus_exception( + logger: "Logger", + exception: Exception, + *, + custom_endpoint_address: Optional[str] = None # pylint: disable=unused-argument + ) -> "ServiceBusError": + if isinstance(exception, AMQPError): + try: + # handling AMQP Errors that have the condition field + condition = exception.condition + description = exception.description + exception = UamqpTransport._handle_amqp_exception_with_condition( + logger, condition, description, exception=exception + ) + except AttributeError: + # handling AMQP Errors that don't have the condition field + exception = UamqpTransport._handle_amqp_exception_without_condition( + logger, exception + ) + elif not isinstance(exception, ServiceBusError): + logger.exception( + "Unexpected error occurred (%r). 
Handler shutting down.", exception + ) + exception = ServiceBusError( + message=f"Handler failed: {exception}.", error=exception + ) + + return exception + +except ImportError: + pass diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_version.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_version.py index a825990033771..a606c9c704bea 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/_version.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_version.py @@ -3,4 +3,4 @@ # Licensed under the MIT License. # ------------------------------------ -VERSION = "7.9.1" +VERSION = "7.10.0" diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_async_auto_lock_renewer.py b/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_async_auto_lock_renewer.py index 0bc36d7b47333..242e20d1c392c 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_async_auto_lock_renewer.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_async_auto_lock_renewer.py @@ -68,7 +68,7 @@ def __init__( ) -> None: self._internal_kwargs = get_dict_with_loop_if_needed(loop) self._shutdown = asyncio.Event() - self._futures: List[asyncio.Future] = [] + self._futures = [] # type: List[asyncio.Future] self._sleep_time = 1 self._renew_period = 10 self._on_lock_renew_failure = on_lock_renew_failure @@ -118,7 +118,7 @@ async def _auto_lock_renew( _log.debug( "Running async lock auto-renew for %r seconds", max_lock_renewal_duration ) - error: Optional[Exception] = None + error = None # type: Optional[Exception] clean_shutdown = False # Only trigger the on_lock_renew_failure if halting was not expected (shutdown, etc) renew_period = renew_period_override or self._renew_period try: diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_async_utils.py b/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_async_utils.py index 27e99d88d4d47..b429888d754b7 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_async_utils.py 
+++ b/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_async_utils.py @@ -3,20 +3,15 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. # ------------------------------------------------------------------------- -from __future__ import annotations + import sys import asyncio import logging import functools -from typing import Dict, Optional, TYPE_CHECKING - -from uamqp import authentication from .._common.constants import JWT_TOKEN_SCOPE, TOKEN_TYPE_JWT, TOKEN_TYPE_SASTOKEN -if TYPE_CHECKING: - from ..aio._servicebus_client_async import ServiceBusClient _log = logging.getLogger(__name__) @@ -42,7 +37,7 @@ def get_running_loop(): return asyncio.get_event_loop() -async def create_authentication(client: ServiceBusClient): +async def create_authentication(client): # pylint: disable=protected-access try: # ignore mypy's warning because token_type is Optional @@ -50,45 +45,26 @@ async def create_authentication(client: ServiceBusClient): except AttributeError: token_type = TOKEN_TYPE_JWT if token_type == TOKEN_TYPE_SASTOKEN: - auth = authentication.JWTTokenAsync( + return (await client._amqp_transport.create_token_auth_async( client._auth_uri, + get_token=functools.partial(client._credential.get_token, client._auth_uri), + token_type=token_type, + config=client._config, + update_token=True + )) + return (await client._amqp_transport.create_token_auth_async( client._auth_uri, - # Since we have handled the token type, the type is already narrowed. 
- functools.partial(client._credential.get_token, client._auth_uri), # type: ignore + get_token=functools.partial(client._credential.get_token, JWT_TOKEN_SCOPE), token_type=token_type, - timeout=client._config.auth_timeout, - http_proxy=client._config.http_proxy, - transport_type=client._config.transport_type, - custom_endpoint_hostname=client._config.custom_endpoint_hostname, - port=client._config.connection_port, - verify=client._config.connection_verify - ) - await auth.update_token() - return auth - return authentication.JWTTokenAsync( - client._auth_uri, - client._auth_uri, - # Same as mentioned above. - functools.partial(client._credential.get_token, JWT_TOKEN_SCOPE), # type: ignore - token_type=token_type, - timeout=client._config.auth_timeout, - http_proxy=client._config.http_proxy, - transport_type=client._config.transport_type, - refresh_window=300, - custom_endpoint_hostname=client._config.custom_endpoint_hostname, - port=client._config.connection_port, - verify=client._config.connection_verify - ) + config=client._config, + update_token=False, + )) -def get_dict_with_loop_if_needed( - loop: Optional[asyncio.AbstractEventLoop], -) -> Dict[str, asyncio.AbstractEventLoop]: +def get_dict_with_loop_if_needed(loop): if sys.version_info >= (3, 10): if loop: - raise ValueError( - "Starting Python 3.10, asyncio no longer supports loop as a parameter." 
- ) + raise ValueError("Starting Python 3.10, asyncio no longer supports loop as a parameter.") elif loop: - return {"loop": loop} + return {'loop': loop} return {} diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_base_handler_async.py b/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_base_handler_async.py index 284f41afed472..f037145c9281a 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_base_handler_async.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_base_handler_async.py @@ -8,27 +8,12 @@ import time from typing import TYPE_CHECKING, Any, Callable, Optional, Dict, Union -import uamqp -from uamqp import compat -from uamqp.message import MessageProperties - -from azure.core.credentials import ( - AccessToken, - AzureSasCredential, - AzureNamedKeyCredential, -) +from azure.core.credentials import AccessToken, AzureSasCredential, AzureNamedKeyCredential -from .._base_handler import ( - _generate_sas_token, - BaseHandler as BaseHandlerSync, - _get_backoff_time, -) +from ._transport._pyamqp_transport_async import PyamqpTransportAsync +from .._base_handler import _generate_sas_token, BaseHandler as BaseHandlerSync, _get_backoff_time from .._common._configuration import Configuration -from .._common.utils import ( - create_properties, - strip_protocol_from_uri, - parse_sas_credential, -) +from .._common.utils import create_properties, strip_protocol_from_uri, parse_sas_credential from .._common.constants import ( TOKEN_TYPE_SASTOKEN, MGMT_REQUEST_OP_TYPE_ENTITY_MGMT, @@ -40,10 +25,16 @@ ServiceBusConnectionError, SessionLockLostError, OperationTimeoutError, - _create_servicebus_exception, ) if TYPE_CHECKING: + try: + # pylint:disable=unused-import + from uamqp.async_ops.client_async import AMQPClientAsync as uamqp_AMQPClientAsync + except ImportError: + pass + from .._pyamqp.aio._client_async import AMQPClientAsync as pyamqp_AMQPClientAsync + from .._pyamqp.message import Message as pyamqp_Message from 
azure.core.credentials_async import AsyncTokenCredential _LOGGER = logging.getLogger(__name__) @@ -99,8 +90,7 @@ class ServiceBusAzureNamedKeyTokenCredentialAsync(object): :type credential: ~azure.core.credentials.AzureNamedKeyCredential """ - def __init__(self, azure_named_key_credential): - # type: (AzureNamedKeyCredential) -> None + def __init__(self, azure_named_key_credential: AzureNamedKeyCredential) -> None: self._credential = azure_named_key_credential self.token_type = b"servicebus.windows.net:sastoken" @@ -117,14 +107,11 @@ class ServiceBusAzureSasTokenCredentialAsync(object): :param azure_sas_credential: The credential to be used for authentication. :type azure_sas_credential: ~azure.core.credentials.AzureSasCredential """ - def __init__(self, azure_sas_credential: AzureSasCredential) -> None: self._credential = azure_sas_credential self.token_type = TOKEN_TYPE_SASTOKEN - async def get_token( - self, *scopes: str, **kwargs: Any # pylint:disable=unused-argument - ) -> AccessToken: + async def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken: # pylint:disable=unused-argument """ This method is automatically called when token is about to expire. """ @@ -137,11 +124,11 @@ def __init__( self, fully_qualified_namespace: str, entity_name: str, - credential: Union[ - "AsyncTokenCredential", AzureSasCredential, AzureNamedKeyCredential - ], + credential: Union["AsyncTokenCredential", AzureSasCredential, AzureNamedKeyCredential], **kwargs: Any ) -> None: + self._amqp_transport = kwargs.pop("amqp_transport", PyamqpTransportAsync) + # If the user provided http:// or sb://, let's be polite and strip that. 
self.fully_qualified_namespace = strip_protocol_from_uri( fully_qualified_namespace.strip() @@ -152,23 +139,30 @@ def __init__( self._entity_path = self._entity_name + ( ("/Subscriptions/" + subscription_name) if subscription_name else "" ) - self._mgmt_target = "{}{}".format(self._entity_path, MANAGEMENT_PATH_SUFFIX) + self._mgmt_target = f"{self._entity_path}{MANAGEMENT_PATH_SUFFIX}" if isinstance(credential, AzureSasCredential): self._credential = ServiceBusAzureSasTokenCredentialAsync(credential) elif isinstance(credential, AzureNamedKeyCredential): - self._credential = ServiceBusAzureNamedKeyTokenCredentialAsync(credential) # type: ignore + self._credential = ServiceBusAzureNamedKeyTokenCredentialAsync(credential) # type: ignore else: - self._credential = credential # type: ignore + self._credential = credential # type: ignore self._container_id = CONTAINER_PREFIX + str(uuid.uuid4())[:8] - self._config = Configuration(**kwargs) + self._config = Configuration( + hostname=self.fully_qualified_namespace, + amqp_transport=self._amqp_transport, + **kwargs + ) self._running = False - self._handler = None # type: uamqp.AMQPClientAsync + self._handler: Optional[Union["uamqp_AMQPClientAsync", "pyamqp_AMQPClientAsync"]] = None self._auth_uri = None - self._properties = create_properties(self._config.user_agent) + self._properties = create_properties( + self._config.user_agent, + amqp_transport=self._amqp_transport, + ) self._shutdown = asyncio.Event() @classmethod - def _convert_connection_string_to_kwargs(cls, conn_str: str, **kwargs: Any): + def _convert_connection_string_to_kwargs(cls, conn_str, **kwargs): # pylint:disable=protected-access return BaseHandlerSync._convert_connection_string_to_kwargs( conn_str, @@ -186,12 +180,14 @@ async def __aenter__(self): await self._open_with_retry() return self - async def __aexit__(self, *args: Any): + async def __aexit__(self, *args): await self.close() async def _handle_exception(self, exception): # pylint: 
disable=protected-access - error = _create_servicebus_exception(_LOGGER, exception) + error = self._amqp_transport.create_servicebus_exception( + _LOGGER, exception, custom_endpoint_address=self._config.custom_endpoint_address + ) try: # If SessionLockLostError or ServiceBusConnectionError happen when a session receiver is running, @@ -244,10 +240,10 @@ def _check_live(self): async def _do_retryable_operation( self, - operation: Callable[..., Any], + operation: Callable, timeout: Optional[float] = None, **kwargs: Any - ): + ) -> Any: require_last_exception = kwargs.pop("require_last_exception", False) operation_requires_timeout = kwargs.pop("operation_requires_timeout", False) retried_times = 0 @@ -267,6 +263,9 @@ async def _do_retryable_operation( return await operation(**kwargs) except StopAsyncIteration: raise + except ImportError: + # If dependency is not installed, do not retry. + raise except Exception as exception: # pylint: disable=broad-except last_exception = await self._handle_exception(exception) if require_last_exception: @@ -315,12 +314,12 @@ async def _backoff( async def _mgmt_request_response( self, mgmt_operation: bytes, - message: uamqp.Message, - callback: Callable[..., Any], + message: Any, + callback: Callable, keep_alive_associated_link: bool = True, timeout: Optional[float] = None, **kwargs: Any - ) -> uamqp.Message: + ) -> "pyamqp_Message": """ Execute an amqp management operation. 
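
The hunk above routes management requests through the pluggable `_amqp_transport` class instead of calling `uamqp` APIs directly, and the timeout scaling moves behind each transport's `TIMEOUT_FACTOR`. Below is a minimal sketch of that pattern; the classes and the fake management client are illustrative stand-ins, not the SDK's real implementations. Only the factor values come from the diff (uamqp expresses timeouts in milliseconds, pyamqp in seconds).

```python
class PyamqpTransport:
    """pyamqp expresses timeouts in seconds, so no scaling is needed."""
    TIMEOUT_FACTOR = 1

    @staticmethod
    def mgmt_client_request(mgmt_client, mgmt_msg, *, timeout=None, **kwargs):
        scaled = timeout * PyamqpTransport.TIMEOUT_FACTOR if timeout else None
        return mgmt_client.mgmt_request(mgmt_msg, timeout=scaled, **kwargs)


class UamqpTransport:
    """uamqp expresses timeouts in milliseconds, so seconds are scaled up."""
    TIMEOUT_FACTOR = 1000

    @staticmethod
    def mgmt_client_request(mgmt_client, mgmt_msg, *, timeout=None, **kwargs):
        scaled = timeout * UamqpTransport.TIMEOUT_FACTOR if timeout else None
        return mgmt_client.mgmt_request(mgmt_msg, timeout=scaled, **kwargs)


class FakeMgmtClient:
    """Stub standing in for an AMQP management client; records what it was given."""
    def mgmt_request(self, msg, timeout=None, **kwargs):
        return {"msg": msg, "timeout": timeout}


client = FakeMgmtClient()
# The handler only ever talks to self._amqp_transport, so the same call site
# serves both libraries; only the timeout scaling differs.
assert UamqpTransport.mgmt_client_request(client, "m", timeout=10)["timeout"] == 10000
assert PyamqpTransport.mgmt_client_request(client, "m", timeout=10)["timeout"] == 10
```

This is why `_mgmt_request_response` in the diff can pass `timeout=timeout` unscaled and let the transport apply its own factor.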
@@ -343,29 +342,31 @@ async def _mgmt_request_response( if keep_alive_associated_link: try: application_properties = { - ASSOCIATEDLINKPROPERTYNAME: self._handler.message_handler.name + ASSOCIATEDLINKPROPERTYNAME: self._amqp_transport.get_handler_link_name(self._handler) } except AttributeError: pass - mgmt_msg = uamqp.Message( - body=message, - properties=MessageProperties( - reply_to=self._mgmt_target, encoding=self._config.encoding, **kwargs - ), + mgmt_msg = self._amqp_transport.create_mgmt_msg( # type: ignore # TODO: fix mypy + message=message, application_properties=application_properties, + config=self._config, + reply_to=self._mgmt_target, + **kwargs ) + try: - return await self._handler.mgmt_request_async( + return await self._amqp_transport.mgmt_client_request_async( + self._handler, mgmt_msg, - mgmt_operation, - op_type=MGMT_REQUEST_OP_TYPE_ENTITY_MGMT, + operation=mgmt_operation, + operation_type=MGMT_REQUEST_OP_TYPE_ENTITY_MGMT, node=self._mgmt_target.encode(self._config.encoding), - timeout=timeout * 1000 if timeout else None, - callback=callback, + timeout=timeout, + callback=callback ) except Exception as exp: # pylint: disable=broad-except - if isinstance(exp, compat.TimeoutException): + if isinstance(exp, self._amqp_transport.TIMEOUT_ERROR): raise OperationTimeoutError(error=exp) raise @@ -373,7 +374,7 @@ async def _mgmt_request_response_with_retry( self, mgmt_operation: bytes, message: Dict[str, Any], - callback: Callable[..., Any], + callback: Callable, timeout: Optional[float] = None, **kwargs: Any ) -> Any: @@ -387,11 +388,6 @@ async def _mgmt_request_response_with_retry( **kwargs ) - def _add_span_request_attributes(self, span): - return BaseHandlerSync._add_span_request_attributes( # pylint: disable=protected-access - self, span - ) - async def _open(self): # pylint: disable=no-self-use raise ValueError("Subclass should override the method.") diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_servicebus_client_async.py 
b/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_servicebus_client_async.py index c3fe351b30b2f..25357ec7af8cd 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_servicebus_client_async.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_servicebus_client_async.py @@ -6,10 +6,11 @@ import logging from weakref import WeakSet from typing_extensions import Literal +import certifi -import uamqp from azure.core.credentials import AzureSasCredential, AzureNamedKeyCredential +from ._transport._pyamqp_transport_async import PyamqpTransportAsync from .._base_handler import _parse_conn_str from ._base_handler_async import ( ServiceBusSharedKeyCredential, @@ -35,7 +36,7 @@ _LOGGER = logging.getLogger(__name__) -class ServiceBusClient(object): # pylint: disable=client-accepts-api-version-keyword +class ServiceBusClient(object): # pylint: disable=client-accepts-api-version-keyword,too-many-instance-attributes """The ServiceBusClient class defines a high level interface for getting ServiceBusSender and ServiceBusReceiver. @@ -73,10 +74,13 @@ class ServiceBusClient(object): # pylint: disable=client-accepts-api-version-key the Service Bus service, allowing network requests to be routed through any application gateways or other paths needed for the host environment. Default is None. The format would be like "sb://:". - If port is not specified in the custom_endpoint_address, by default port 443 will be used. + If port is not specified in the `custom_endpoint_address`, by default port 443 will be used. :keyword str connection_verify: Path to the custom CA_BUNDLE file of the SSL certificate which is used to authenticate the identity of the connection endpoint. Default is None in which case `certifi.where()` will be used. + :keyword uamqp_transport: Whether to use the `uamqp` library as the underlying transport. The default value is + False and the Pure Python AMQP library will be used as the underlying transport. + :paramtype uamqp_transport: bool .. 
admonition:: Example: @@ -102,6 +106,13 @@ def __init__( retry_mode: str = "exponential", **kwargs: Any ) -> None: + uamqp_transport = kwargs.pop("uamqp_transport", False) + if uamqp_transport: + try: + from ._transport._uamqp_transport_async import UamqpTransportAsync + except ImportError: + raise ValueError("To use the uAMQP transport, please install `uamqp>=1.6.3,<2.0.0`.") + self._amqp_transport = UamqpTransportAsync if uamqp_transport else PyamqpTransportAsync # If the user provided http:// or sb://, let's be polite and strip that. self.fully_qualified_namespace = strip_protocol_from_uri( fully_qualified_namespace.strip() @@ -112,35 +123,40 @@ def __init__( retry_backoff_factor=retry_backoff_factor, retry_backoff_max=retry_backoff_max, retry_mode=retry_mode, + hostname=self.fully_qualified_namespace, + amqp_transport=self._amqp_transport, **kwargs ) self._connection = None # Optional entity name, can be the name of Queue or Topic. Intentionally not advertised, typically be needed. self._entity_name = kwargs.get("entity_name") - self._auth_uri = "sb://{}".format(self.fully_qualified_namespace) + self._auth_uri = f"sb://{self.fully_qualified_namespace}" if self._entity_name: - self._auth_uri = "{}/{}".format(self._auth_uri, self._entity_name) + self._auth_uri = f"{self._auth_uri}/{self._entity_name}" # Internal flag for switching whether to apply connection sharing, pending fix in uamqp library self._connection_sharing = False - self._handlers = WeakSet() # type: WeakSet - - self._custom_endpoint_address = kwargs.get("custom_endpoint_address") + self._handlers: WeakSet = WeakSet() + self._custom_endpoint_address = kwargs.get('custom_endpoint_address') self._connection_verify = kwargs.get("connection_verify") async def __aenter__(self): if self._connection_sharing: - await self._create_uamqp_connection() + await self._create_connection() return self async def __aexit__(self, *args): await self.close() - async def _create_uamqp_connection(self): + async def 
_create_connection(self): auth = await create_authentication(self) - self._connection = uamqp.ConnectionAsync( - hostname=self.fully_qualified_namespace, - sasl=auth, - debug=self._config.logging_enable, + self._connection = self._amqp_transport.create_connection_async( + host=self.fully_qualified_namespace, + auth=auth.sasl, + network_trace=self._config.logging_enable, + custom_endpoint_address=self._custom_endpoint_address, + ssl_opts={'ca_certs':self._connection_verify or certifi.where()}, + transport_type=self._config.transport_type, + http_proxy=self._config.http_proxy ) @classmethod @@ -184,6 +200,9 @@ def from_connection_string( :keyword str connection_verify: Path to the custom CA_BUNDLE file of the SSL certificate which is used to authenticate the identity of the connection endpoint. Default is None in which case `certifi.where()` will be used. + :keyword uamqp_transport: Whether to use the `uamqp` library as the underlying transport. The default value is + False and the Pure Python AMQP library will be used as the underlying transport. + :paramtype uamqp_transport: bool :rtype: ~azure.servicebus.aio.ServiceBusClient .. admonition:: Example: @@ -233,7 +252,7 @@ async def close(self) -> None: self._handlers.clear() if self._connection_sharing and self._connection: - await self._connection.destroy_async() + await self._connection.close() def get_queue_sender(self, queue_name: str, **kwargs: Any) -> ServiceBusSender: """Get ServiceBusSender for the specific queue. 
@@ -277,6 +296,7 @@ def get_queue_sender(self, queue_name: str, **kwargs: Any) -> ServiceBusSender: retry_backoff_max=self._config.retry_backoff_max, custom_endpoint_address=self._custom_endpoint_address, connection_verify=self._connection_verify, + amqp_transport=self._amqp_transport, **kwargs ) self._handlers.add(handler) @@ -390,6 +410,7 @@ def get_queue_receiver( prefetch_count=prefetch_count, custom_endpoint_address=self._custom_endpoint_address, connection_verify=self._connection_verify, + amqp_transport=self._amqp_transport, **kwargs ) self._handlers.add(handler) @@ -436,6 +457,7 @@ def get_topic_sender(self, topic_name: str, **kwargs: Any) -> ServiceBusSender: retry_backoff_max=self._config.retry_backoff_max, custom_endpoint_address=self._custom_endpoint_address, connection_verify=self._connection_verify, + amqp_transport=self._amqp_transport, **kwargs ) self._handlers.add(handler) @@ -549,6 +571,7 @@ def get_subscription_receiver( prefetch_count=prefetch_count, custom_endpoint_address=self._custom_endpoint_address, connection_verify=self._connection_verify, + amqp_transport=self._amqp_transport, **kwargs ) except ValueError: @@ -578,6 +601,7 @@ def get_subscription_receiver( prefetch_count=prefetch_count, custom_endpoint_address=self._custom_endpoint_address, connection_verify=self._connection_verify, + amqp_transport=self._amqp_transport, **kwargs ) self._handlers.add(handler) diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_servicebus_receiver_async.py b/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_servicebus_receiver_async.py index 425b73df29879..59199dfce7cb1 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_servicebus_receiver_async.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_servicebus_receiver_async.py @@ -2,18 +2,18 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for license information. 
# -------------------------------------------------------------------------------------------- -# pylint:disable=too-many-lines +#pylint: disable=too-many-lines + import asyncio import collections import datetime import functools import logging +import time import warnings +from enum import Enum from typing import Any, List, Optional, AsyncIterator, Union, TYPE_CHECKING, cast -from uamqp import ReceiveClientAsync, types, Message -from uamqp.constants import SenderSettleMode - from ..exceptions import ServiceBusError from ._servicebus_session_async import ServiceBusSession from ._base_handler_async import BaseHandler @@ -41,19 +41,31 @@ MGMT_REQUEST_DEAD_LETTER_REASON, MGMT_REQUEST_DEAD_LETTER_ERROR_DESCRIPTION, MGMT_RESPONSE_MESSAGE_EXPIRATION, - SPAN_NAME_RECEIVE_DEFERRED, - SPAN_NAME_PEEK, - ServiceBusToAMQPReceiveModeMap, ) from .._common import mgmt_handlers -from .._common.utils import ( +from .._common.utils import utc_from_timestamp +from .._common.tracing import ( receive_trace_context_manager, - utc_from_timestamp, + settle_trace_context_manager, get_receive_links, + get_span_link_from_message, + SPAN_NAME_RECEIVE_DEFERRED, + SPAN_NAME_PEEK, ) -from ._async_utils import create_authentication, get_running_loop +from ._async_utils import create_authentication if TYPE_CHECKING: + try: + # pylint:disable=unused-import + from uamqp.async_ops.client_async import ReceiveClientAsync as uamqp_ReceiveClientAsync + from uamqp.authentication import JWTTokenAsync as uamqp_JWTTokenAuthAsync + from uamqp.message import Message as uamqp_Message + except ImportError: + pass + from ._transport._base_async import AmqpTransportAsync + from .._pyamqp.message import Message as pyamqp_Message + from .._pyamqp.aio import ReceiveClientAsync as pyamqp_ReceiveClientAsync + from .._pyamqp.aio._authentication_async import JWTTokenAuthAsync as pyamqp_JWTTokenAuthAsync from azure.core.credentials_async import AsyncTokenCredential from azure.core.credentials import AzureSasCredential, 
AzureNamedKeyCredential from .._common.auto_lock_renewer import AutoLockRenewer @@ -125,9 +137,7 @@ class ServiceBusReceiver(collections.abc.AsyncIterator, BaseHandler, ReceiverMix def __init__( self, fully_qualified_namespace: str, - credential: Union[ - "AsyncTokenCredential", "AzureSasCredential", "AzureNamedKeyCredential" - ], + credential: Union["AsyncTokenCredential", "AzureSasCredential", "AzureNamedKeyCredential"], *, queue_name: Optional[str] = None, topic_name: Optional[str] = None, @@ -138,9 +148,13 @@ def __init__( max_wait_time: Optional[float] = None, auto_lock_renewer: Optional["AutoLockRenewer"] = None, prefetch_count: int = 0, - **kwargs: Any, + **kwargs: Any ) -> None: - self._message_iter: Optional[AsyncIterator[ServiceBusReceivedMessage]] = None + self._session_id = None + self._message_iter: Optional[AsyncIterator[Union["uamqp_Message", "pyamqp_Message"]]] = ( + None + ) + self._amqp_transport: "AmqpTransportAsync" if kwargs.get("entity_name"): super(ServiceBusReceiver, self).__init__( fully_qualified_namespace=fully_qualified_namespace, @@ -152,7 +166,7 @@ def __init__( max_wait_time=max_wait_time, auto_lock_renewer=auto_lock_renewer, prefetch_count=prefetch_count, - **kwargs, + **kwargs ) else: if queue_name and topic_name: @@ -181,7 +195,7 @@ def __init__( max_wait_time=max_wait_time, auto_lock_renewer=auto_lock_renewer, prefetch_count=prefetch_count, - **kwargs, + **kwargs ) self._populate_attributes( @@ -192,53 +206,36 @@ def __init__( max_wait_time=max_wait_time, auto_lock_renewer=auto_lock_renewer, prefetch_count=prefetch_count, - **kwargs, + **kwargs ) self._session = ( - None - if self._session_id is None - else ServiceBusSession(self._session_id, self) + None if self._session_id is None else ServiceBusSession(cast(str, self._session_id), self) ) self._receive_context = asyncio.Event() + self._handler: Union["pyamqp_ReceiveClientAsync", "uamqp_ReceiveClientAsync"] + self._build_received_message = functools.partial( + 
self._amqp_transport.build_received_message, + self, + ServiceBusReceivedMessage + ) - # Python 3.5 does not allow for yielding from a coroutine, so instead of the try-finally functional wrapper - # trick to restore the timeout, let's use a wrapper class to maintain the override that may be specified. - class _IterContextualWrapper(collections.abc.AsyncIterator): - def __init__(self, receiver, max_wait_time=None): - self.receiver = receiver - self.max_wait_time = max_wait_time - - async def __anext__(self): - # pylint: disable=protected-access - original_timeout = None - # This is not threadsafe, but gives us a way to handle if someone passes - # different max_wait_times to different iterators and uses them in concert. - if self.max_wait_time and self.receiver and self.receiver._handler: - original_timeout = self.receiver._handler._timeout - self.receiver._handler._timeout = self.max_wait_time * 1000 - try: - self.receiver._receive_context.set() - message = await self.receiver._inner_anext() - links = get_receive_links(message) - with receive_trace_context_manager(self.receiver, links=links): - return message - finally: - self.receiver._receive_context.clear() - if original_timeout: - try: - self.receiver._handler._timeout = original_timeout - except AttributeError: # Handler may be disposed already. - pass + self._iter_contextual_wrapper = functools.partial( + self._amqp_transport.iter_contextual_wrapper_async, self + ) + self._iter_next = functools.partial( + self._amqp_transport.iter_next_async, + self + ) def __aiter__(self): - return self._IterContextualWrapper(self) + return self._iter_contextual_wrapper() - async def _inner_anext(self): + async def _inner_anext(self, wait_time=None): # We do this weird wrapping such that an imperitive next() call, and a generator-based iter both trace sanely. 
self._check_live() while True: try: - return await self._do_retryable_operation(self._iter_next) + return await self._do_retryable_operation(self._iter_next, wait_time=wait_time) except StopAsyncIteration: self._message_iter = None raise @@ -253,20 +250,6 @@ async def __anext__(self): finally: self._receive_context.clear() - async def _iter_next(self): - await self._open() - if not self._message_iter: - self._message_iter = self._handler.receive_messages_iter_async() - uamqp_message = await self._message_iter.__anext__() - message = self._build_message(uamqp_message) - if ( - self._auto_lock_renewer - and not self._session - and self._receive_mode != ServiceBusReceiveMode.RECEIVE_AND_DELETE - ): - self._auto_lock_renewer.register(self, message) - return message - @classmethod def _from_connection_string( cls, conn_str: str, **kwargs: Any @@ -328,33 +311,31 @@ def _from_connection_string( ) return cls(**constructor_args) - def _create_handler(self, auth): - self._handler = ReceiveClientAsync( - self._get_source(), + def _create_handler(self, auth: Union["pyamqp_JWTTokenAuthAsync", "uamqp_JWTTokenAuthAsync"]) -> None: + + self._handler = self._amqp_transport.create_receive_client_async( + receiver=self, + source=self._get_source(), auth=auth, - debug=self._config.logging_enable, + network_trace=self._config.logging_enable, properties=self._properties, - error_policy=self._error_policy, + retry_policy=self._error_policy, client_name=self._name, - on_attach=self._on_attach, - auto_complete=False, - encoding=self._config.encoding, - receive_settle_mode=ServiceBusToAMQPReceiveModeMap[self._receive_mode], - send_settle_mode=SenderSettleMode.Settled - if self._receive_mode == ServiceBusReceiveMode.RECEIVE_AND_DELETE - else None, - timeout=self._max_wait_time * 1000 if self._max_wait_time else 0, - prefetch=self._prefetch_count, + receive_mode=self._receive_mode, + timeout=self._max_wait_time * self._amqp_transport.TIMEOUT_FACTOR + if self._max_wait_time + else 0, + 
link_credit=self._prefetch_count, # If prefetch is 1, then keep_alive coroutine serves as keep receiving for releasing messages keep_alive_interval=self._config.keep_alive if self._prefetch_count != 1 else 5, shutdown_after_timeout=False, - link_properties={CONSUMER_IDENTIFIER: self._name}, + link_properties = {CONSUMER_IDENTIFIER:self._name} ) if self._prefetch_count == 1: # pylint: disable=protected-access - self._handler._message_received = self._enhanced_message_received + self._amqp_transport.set_handler_message_received_async(self) async def _open(self): # pylint: disable=protected-access @@ -387,45 +368,37 @@ async def _receive( amqp_receive_client = self._handler received_messages_queue = amqp_receive_client._received_messages max_message_count = max_message_count or self._prefetch_count - timeout_ms = ( - 1000 * (timeout or self._max_wait_time) + timeout_seconds = ( + self._amqp_transport.TIMEOUT_FACTOR * (timeout or self._max_wait_time) if (timeout or self._max_wait_time) else 0 ) - abs_timeout_ms = ( - amqp_receive_client._counter.get_current_ms() + timeout_ms - if timeout_ms + abs_timeout = ( + self._amqp_transport.get_current_time(amqp_receive_client) + timeout_seconds + if timeout_seconds else 0 ) - batch: List[Message] = [] - while ( - not received_messages_queue.empty() and len(batch) < max_message_count - ): + batch: Union[List["uamqp_Message"], List["pyamqp_Message"]] = [] + while not received_messages_queue.empty() and len(batch) < max_message_count: batch.append(received_messages_queue.get()) received_messages_queue.task_done() if len(batch) >= max_message_count: - return [self._build_message(message) for message in batch] + return [self._build_received_message(message) for message in batch] # Dynamically issue link credit if max_message_count > 1 when the prefetch_count is the default value 1 - if ( - max_message_count - and self._prefetch_count == 1 - and max_message_count > 1 - ): + if max_message_count and self._prefetch_count == 1 and 
max_message_count > 1: link_credit_needed = max_message_count - len(batch) - await amqp_receive_client.message_handler.reset_link_credit_async( - link_credit_needed - ) + await self._amqp_transport.reset_link_credit_async(amqp_receive_client, link_credit_needed) first_message_received = expired = False receiving = True while receiving and not expired and len(batch) < max_message_count: while receiving and received_messages_queue.qsize() < max_message_count: if ( - abs_timeout_ms - and amqp_receive_client._counter.get_current_ms() - > abs_timeout_ms + abs_timeout + and self._amqp_transport.get_current_time(amqp_receive_client) + > abs_timeout ): expired = True break @@ -439,17 +412,16 @@ async def _receive( ): # first message(s) received, continue receiving for some time first_message_received = True - abs_timeout_ms = ( - amqp_receive_client._counter.get_current_ms() - + self._further_pull_receive_timeout_ms + abs_timeout = ( + self._amqp_transport.get_current_time(amqp_receive_client) + + self._further_pull_receive_timeout ) while ( - not received_messages_queue.empty() - and len(batch) < max_message_count + not received_messages_queue.empty() and len(batch) < max_message_count ): batch.append(received_messages_queue.get()) received_messages_queue.task_done() - return [self._build_message(message) for message in batch] + return [self._build_received_message(message) for message in batch] finally: self._receive_context.clear() @@ -480,16 +452,18 @@ async def _settle_message_with_retry( message="The lock on the message lock has expired.", error=message.auto_renew_error, ) - - await self._do_retryable_operation( - self._settle_message, - timeout=None, - message=message, - settle_operation=settle_operation, - dead_letter_reason=dead_letter_reason, - dead_letter_error_description=dead_letter_error_description, - ) - message._settled = True + link = get_span_link_from_message(message) + trace_links = [link] if link else [] + with settle_trace_context_manager(self, 
settle_operation, links=trace_links): + await self._do_retryable_operation( + self._settle_message, + timeout=None, + message=message, + settle_operation=settle_operation, + dead_letter_reason=dead_letter_reason, + dead_letter_error_description=dead_letter_error_description, + ) + message._settled = True async def _settle_message( # type: ignore self, @@ -502,14 +476,12 @@ async def _settle_message( # type: ignore try: if not message._is_deferred_message: try: - await get_running_loop().run_in_executor( - None, - self._settle_message_via_receiver_link( - message, - settle_operation, - dead_letter_reason=dead_letter_reason, - dead_letter_error_description=dead_letter_error_description, - ), + await self._amqp_transport.settle_message_via_receiver_link_async( + self._handler, + message, + settle_operation, + dead_letter_reason=dead_letter_reason, + dead_letter_error_description=dead_letter_error_description, ) return except RuntimeError as exception: @@ -546,7 +518,7 @@ async def _settle_message_via_mgmt_link( ): message = { MGMT_REQUEST_DISPOSITION_STATUS: settlement, - MGMT_REQUEST_LOCK_TOKENS: types.AMQPArray(lock_tokens), + MGMT_REQUEST_LOCK_TOKENS: self._amqp_transport.AMQP_ARRAY_VALUE(lock_tokens), } self._populate_message_properties(message) @@ -557,9 +529,8 @@ async def _settle_message_via_mgmt_link( REQUEST_RESPONSE_UPDATE_DISPOSTION_OPERATION, message, mgmt_handlers.default ) - async def _renew_locks(self, *lock_tokens, timeout=None): - # type: (str, Optional[float]) -> Any - message = {MGMT_REQUEST_LOCK_TOKENS: types.AMQPArray(lock_tokens)} + async def _renew_locks(self, *lock_tokens: str, timeout: Optional[float] = None) -> Any: + message = {MGMT_REQUEST_LOCK_TOKENS: self._amqp_transport.AMQP_ARRAY_VALUE(lock_tokens)} return await self._mgmt_request_response_with_retry( REQUEST_RESPONSE_RENEWLOCK_OPERATION, message, @@ -618,7 +589,7 @@ def _get_streaming_message_iter( """ if max_wait_time is not None and max_wait_time <= 0: raise ValueError("The 
max_wait_time must be greater than 0.") - return self._IterContextualWrapper(self, max_wait_time) + return self._iter_contextual_wrapper(max_wait_time) async def receive_messages( self, @@ -662,6 +633,7 @@ async def receive_messages( raise ValueError("The max_wait_time must be greater than 0.") if max_message_count is not None and max_message_count <= 0: raise ValueError("The max_message_count must be greater than 0") + start_time = time.time_ns() messages = await self._do_retryable_operation( self._receive, max_message_count=max_message_count, @@ -669,7 +641,7 @@ async def receive_messages( operation_requires_timeout=True, ) links = get_receive_links(messages) - with receive_trace_context_manager(self, links=links): + with receive_trace_context_manager(self, links=links, start_time=start_time): if ( self._auto_lock_renewer and not self._session @@ -680,11 +652,7 @@ async def receive_messages( return messages async def receive_deferred_messages( - self, - sequence_numbers: Union[int, List[int]], - *, - timeout: Optional[float] = None, - **kwargs: Any, + self, sequence_numbers: Union[int, List[int]], *, timeout: Optional[float] = None, **kwargs: Any ) -> List[ServiceBusReceivedMessage]: """Receive messages that have previously been deferred. @@ -718,16 +686,16 @@ async def receive_deferred_messages( if len(sequence_numbers) == 0: return [] # no-op on empty list. 
await self._open() - uamqp_receive_mode = ServiceBusToAMQPReceiveModeMap[self._receive_mode] + uamqp_receive_mode = self._amqp_transport.ServiceBusToAMQPReceiveModeMap[self._receive_mode] try: - receive_mode = uamqp_receive_mode.value.value + receive_mode = cast(Enum, uamqp_receive_mode).value except AttributeError: - receive_mode = int(uamqp_receive_mode.value) + receive_mode = int(uamqp_receive_mode) message = { - MGMT_REQUEST_SEQUENCE_NUMBERS: types.AMQPArray( - [types.AMQPLong(s) for s in sequence_numbers] + MGMT_REQUEST_SEQUENCE_NUMBERS: self._amqp_transport.AMQP_ARRAY_VALUE( + [self._amqp_transport.AMQP_LONG_VALUE(s) for s in sequence_numbers] ), - MGMT_REQUEST_RECEIVER_SETTLE_MODE: types.AMQPuInt(receive_mode), + MGMT_REQUEST_RECEIVER_SETTLE_MODE: self._amqp_transport.AMQP_UINT_VALUE(receive_mode), } self._populate_message_properties(message) @@ -735,9 +703,10 @@ async def receive_deferred_messages( handler = functools.partial( mgmt_handlers.deferred_message_op, receive_mode=self._receive_mode, - message_type=ServiceBusReceivedMessage, receiver=self, + amqp_transport=self._amqp_transport, ) + start_time = time.time_ns() messages = await self._mgmt_request_response_with_retry( REQUEST_RESPONSE_RECEIVE_BY_SEQUENCE_NUMBER, message, @@ -746,7 +715,7 @@ async def receive_deferred_messages( ) links = get_receive_links(message) with receive_trace_context_manager( - self, span_name=SPAN_NAME_RECEIVE_DEFERRED, links=links + self, span_name=SPAN_NAME_RECEIVE_DEFERRED, links=links, start_time=start_time ): if ( self._auto_lock_renewer @@ -758,22 +727,15 @@ async def receive_deferred_messages( return messages async def peek_messages( - self, - max_message_count: int = 1, - *, - sequence_number: int = 0, - timeout: Optional[float] = None, - **kwargs: Any, + self, max_message_count: int = 1, *, sequence_number: int = 0, timeout: Optional[float] = None, **kwargs: Any ) -> List[ServiceBusReceivedMessage]: """Browse messages currently pending in the queue. 
Peeked messages are not removed from queue, nor are they locked. They cannot be completed, deferred or dead-lettered. - For more information about message browsing see https://aka.ms/azsdk/servicebus/message-browsing - - :param int max_message_count: The maximum number of messages to try and peek. The actual number of messages - returned may be fewer and are subject to service limits. The default value is 1. + :param int max_message_count: The maximum number of messages to try and peek. The default + value is 1. :keyword int sequence_number: A message sequence number from which to start browsing messages. :keyword Optional[float] timeout: The total operation timeout in seconds including all the retries. The value must be greater than 0 if specified. The default value is None, meaning no timeout. @@ -801,17 +763,20 @@ async def peek_messages( await self._open() message = { - MGMT_REQUEST_FROM_SEQUENCE_NUMBER: types.AMQPLong(sequence_number), + MGMT_REQUEST_FROM_SEQUENCE_NUMBER: self._amqp_transport.AMQP_LONG_VALUE(sequence_number), MGMT_REQUEST_MAX_MESSAGE_COUNT: max_message_count, } self._populate_message_properties(message) - handler = functools.partial(mgmt_handlers.peek_op, receiver=self) + handler = functools.partial(mgmt_handlers.peek_op, receiver=self, amqp_transport=self._amqp_transport) + start_time = time.time_ns() messages = await self._mgmt_request_response_with_retry( REQUEST_RESPONSE_PEEK_OPERATION, message, handler, timeout=timeout ) links = get_receive_links(message) - with receive_trace_context_manager(self, span_name=SPAN_NAME_PEEK, links=links): + with receive_trace_context_manager( + self, span_name=SPAN_NAME_PEEK, links=links, start_time=start_time + ): return messages async def complete_message(self, message: ServiceBusReceivedMessage) -> None: @@ -891,7 +856,7 @@ async def dead_letter_message( self, message: ServiceBusReceivedMessage, reason: Optional[str] = None, - error_description: Optional[str] = None, + error_description: Optional[str] = 
None ) -> None: """Move the message to the Dead Letter queue. @@ -926,11 +891,7 @@ async def dead_letter_message( ) async def renew_message_lock( - self, - message: ServiceBusReceivedMessage, - *, - timeout: Optional[float] = None, - **kwargs: Any, + self, message: ServiceBusReceivedMessage, *, timeout: Optional[float] = None, **kwargs: Any ) -> datetime.datetime: # pylint: disable=protected-access,no-member """Renew the message lock. @@ -999,6 +960,4 @@ def client_identifier(self) -> str: return self._name def __str__(self) -> str: - return ( - f"Receiver client id: {self.client_identifier}, entity: {self.entity_path}" - ) + return f"Receiver client id: {self.client_identifier}, entity: {self.entity_path}" diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_servicebus_sender_async.py b/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_servicebus_sender_async.py index b1d43f0f2a7a9..aede9f10716a4 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_servicebus_sender_async.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_servicebus_sender_async.py @@ -8,8 +8,6 @@ import warnings from typing import Any, TYPE_CHECKING, Union, List, Optional, Mapping, cast -import uamqp -from uamqp import SendClientAsync, types from azure.core.credentials import AzureSasCredential, AzureNamedKeyCredential from .._common.message import ( @@ -23,18 +21,32 @@ REQUEST_RESPONSE_SCHEDULE_MESSAGE_OPERATION, REQUEST_RESPONSE_CANCEL_SCHEDULED_MESSAGE_OPERATION, MGMT_REQUEST_SEQUENCE_NUMBERS, - SPAN_NAME_SCHEDULE, + MAX_MESSAGE_LENGTH_BYTES, ) from .._common import mgmt_handlers -from .._common.utils import ( - transform_messages_if_needed, +from .._common.utils import transform_outbound_messages +from .._common.tracing import ( send_trace_context_manager, trace_message, + is_tracing_enabled, + get_span_links_from_batch, + get_span_link_from_message, + SPAN_NAME_SCHEDULE, + TraceAttributes, ) from ._async_utils import create_authentication if 
TYPE_CHECKING: from azure.core.credentials_async import AsyncTokenCredential + try: + # pylint:disable=unused-import + from uamqp.async_ops.client_async import SendClientAsync as uamqp_SendClientAsync + from uamqp.authentication import JWTTokenAsync as uamqp_JWTTokenAuthAsync + except ImportError: + pass + from .._pyamqp.aio import SendClientAsync as pyamqp_SendClientAsync + from .._pyamqp.aio._authentication_async import JWTTokenAuthAsync as pyamqp_JWTTokenAuthAsync + from ._transport._base_async import AmqpTransportAsync MessageTypes = Union[ @@ -102,6 +114,7 @@ def __init__( topic_name: Optional[str] = None, **kwargs: Any, ) -> None: + self._amqp_transport: "AmqpTransportAsync" if kwargs.get("entity_name"): super(ServiceBusSender, self).__init__( fully_qualified_namespace=fully_qualified_namespace, @@ -130,6 +143,7 @@ def __init__( self._max_message_size_on_link = 0 self._create_attribute(**kwargs) self._connection = kwargs.get("connection") + self._handler: Union["pyamqp_SendClientAsync", "uamqp_SendClientAsync"] @classmethod def _from_connection_string( @@ -166,16 +180,17 @@ def _from_connection_string( constructor_args = cls._convert_connection_string_to_kwargs(conn_str, **kwargs) return cls(**constructor_args) - def _create_handler(self, auth): - self._handler = SendClientAsync( - self._entity_uri, + def _create_handler( + self, auth: Union["uamqp_JWTTokenAuthAsync", "pyamqp_JWTTokenAuthAsync"] + ) -> None: + + self._handler = self._amqp_transport.create_send_client_async( + config=self._config, + target=self._entity_uri, auth=auth, - debug=self._config.logging_enable, properties=self._properties, - error_policy=self._error_policy, + retry_policy=self._error_policy, client_name=self._name, - keep_alive_interval=self._config.keep_alive, - encoding=self._config.encoding, ) async def _open(self): @@ -192,21 +207,22 @@ async def _open(self): await asyncio.sleep(0.05) self._running = True self._max_message_size_on_link = ( - 
self._handler.message_handler._link.peer_max_message_size - or uamqp.constants.MAX_MESSAGE_LENGTH_BYTES + self._amqp_transport.get_remote_max_message_size(self._handler) + or MAX_MESSAGE_LENGTH_BYTES ) except: await self._close_handler() raise - async def _send(self, message, timeout=None, last_exception=None): - await self._open() - default_timeout = self._handler._msg_timeout # pylint: disable=protected-access - try: - self._set_msg_timeout(timeout, last_exception) - await self._handler.send_message_async(message.message) - finally: # reset the timeout of the handler back to the default value - self._set_msg_timeout(default_timeout, None) + async def _send( + self, + message: Union[ServiceBusMessage, ServiceBusMessageBatch], + timeout: Optional[float] = None, + last_exception: Optional[Exception] = None + ) -> None: + await self._amqp_transport.send_messages_async( + self, message, _LOGGER, timeout=timeout, last_exception=last_exception + ) async def schedule_messages( self, @@ -242,22 +258,33 @@ async def schedule_messages( # pylint: disable=protected-access self._check_live() - obj_messages = transform_messages_if_needed(messages, ServiceBusMessage) + obj_messages = transform_outbound_messages( + messages, ServiceBusMessage, to_outgoing_amqp_message=self._amqp_transport.to_outgoing_amqp_message + ) if timeout is not None and timeout <= 0: raise ValueError("The timeout must be greater than 0.") - with send_trace_context_manager(span_name=SPAN_NAME_SCHEDULE) as send_span: - if isinstance(obj_messages, ServiceBusMessage): - request_body = self._build_schedule_request( - schedule_time_utc, send_span, obj_messages - ) - else: - if len(obj_messages) == 0: - return [] # No-op on empty list. 
- request_body = self._build_schedule_request( - schedule_time_utc, send_span, *obj_messages - ) - if send_span: - self._add_span_request_attributes(send_span) + + tracing_attributes = { + TraceAttributes.TRACE_NET_PEER_NAME_ATTRIBUTE: self.fully_qualified_namespace, + TraceAttributes.TRACE_MESSAGING_DESTINATION_ATTRIBUTE: self.entity_name, + } + if isinstance(obj_messages, ServiceBusMessage): + request_body, trace_links = self._build_schedule_request( + schedule_time_utc, + self._amqp_transport, + tracing_attributes, + obj_messages + ) + else: + if len(obj_messages) == 0: + return [] # No-op on empty list. + request_body, trace_links = self._build_schedule_request( + schedule_time_utc, + self._amqp_transport, + tracing_attributes, + *obj_messages + ) + with send_trace_context_manager(self, span_name=SPAN_NAME_SCHEDULE, links=trace_links): return await self._mgmt_request_response_with_retry( REQUEST_RESPONSE_SCHEDULE_MESSAGE_OPERATION, request_body, @@ -298,12 +325,12 @@ async def cancel_scheduled_messages( if timeout is not None and timeout <= 0: raise ValueError("The timeout must be greater than 0.") if isinstance(sequence_numbers, int): - numbers = [types.AMQPLong(sequence_numbers)] + numbers = [self._amqp_transport.AMQP_LONG_VALUE(sequence_numbers)] else: - numbers = [types.AMQPLong(s) for s in sequence_numbers] + numbers = [self._amqp_transport.AMQP_LONG_VALUE(s) for s in sequence_numbers] if len(numbers) == 0: return None # no-op on empty list. 
- request_body = {MGMT_REQUEST_SEQUENCE_NUMBERS: types.AMQPArray(numbers)} + request_body = {MGMT_REQUEST_SEQUENCE_NUMBERS: self._amqp_transport.AMQP_ARRAY_VALUE(numbers)} return await self._mgmt_request_response_with_retry( REQUEST_RESPONSE_CANCEL_SCHEDULED_MESSAGE_OPERATION, request_body, @@ -355,28 +382,51 @@ async def send_messages( if timeout is not None and timeout <= 0: raise ValueError("The timeout must be greater than 0.") - with send_trace_context_manager() as send_span: - if isinstance(message, ServiceBusMessageBatch): - obj_message = message # type: MessageObjTypes + try: # Short circuit noop if an empty list or batch is provided. + if len(cast(Union[List, ServiceBusMessageBatch], message)) == 0: # pylint: disable=len-as-condition + return + except TypeError: # continue if ServiceBusMessage + pass + + obj_message: Union[ServiceBusMessage, ServiceBusMessageBatch] + if isinstance(message, ServiceBusMessageBatch): + # If AmqpTransports are not the same, create batch with correct BatchMessage. + if self._amqp_transport is not message._amqp_transport: # pylint: disable=protected-access + # pylint: disable=protected-access + batch = await self.create_message_batch() + batch._from_list(message._messages) # type: ignore + obj_message = batch else: - obj_message = transform_messages_if_needed( # type: ignore - message, ServiceBusMessage + obj_message = message + else: + obj_message = transform_outbound_messages( # type: ignore + message, ServiceBusMessage, self._amqp_transport.to_outgoing_amqp_message + ) + try: + batch = await self.create_message_batch() + batch._from_list(obj_message) # type: ignore # pylint: disable=protected-access + obj_message = batch + except TypeError: # Message was not a list or generator. 
+ # pylint: disable=protected-access + obj_message._message = trace_message( + obj_message._message, + amqp_transport=self._amqp_transport, + additional_attributes={ + TraceAttributes.TRACE_NET_PEER_NAME_ATTRIBUTE: self.fully_qualified_namespace, + TraceAttributes.TRACE_MESSAGING_DESTINATION_ATTRIBUTE: self.entity_name, + } ) - try: - batch = await self.create_message_batch() - batch._from_list(obj_message, send_span) # type: ignore # pylint: disable=protected-access - obj_message = batch - except TypeError: # Message was not a list or generator. - trace_message(cast(ServiceBusMessage, obj_message), send_span) - if ( - isinstance(obj_message, ServiceBusMessageBatch) - and len(obj_message) == 0 - ): # pylint: disable=len-as-condition - return # Short circuit noop if an empty list or batch is provided. - - if send_span: - self._add_span_request_attributes(send_span) + trace_links = [] + if is_tracing_enabled(): + if isinstance(obj_message, ServiceBusMessageBatch): + trace_links = get_span_links_from_batch(obj_message) + else: + link = get_span_link_from_message(obj_message._message) # pylint: disable=protected-access + if link: + trace_links.append(link) + + with send_trace_context_manager(self, links=trace_links): await self._do_retryable_operation( self._send, message=obj_message, @@ -411,13 +461,17 @@ async def create_message_batch( if max_size_in_bytes and max_size_in_bytes > self._max_message_size_on_link: raise ValueError( - "Max message size: {} is too large, acceptable max batch size is: {} bytes.".format( - max_size_in_bytes, self._max_message_size_on_link - ) + f"Max message size: {max_size_in_bytes} is too large, " + f"acceptable max batch size is: {self._max_message_size_on_link} bytes."
) return ServiceBusMessageBatch( - max_size_in_bytes=(max_size_in_bytes or self._max_message_size_on_link) + max_size_in_bytes=(max_size_in_bytes or self._max_message_size_on_link), + amqp_transport=self._amqp_transport, + tracing_attributes = { + TraceAttributes.TRACE_NET_PEER_NAME_ATTRIBUTE: self.fully_qualified_namespace, + TraceAttributes.TRACE_MESSAGING_DESTINATION_ATTRIBUTE: self.entity_name, + } ) @property diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_servicebus_session_async.py b/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_servicebus_session_async.py index e9f2ab23324b0..49f1e298df931 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_servicebus_session_async.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_servicebus_session_async.py @@ -77,7 +77,7 @@ async def set_state( ) -> None: """Set the session state. - :param state: The state value. Setting state to None will clear the current session. + :param state: The state value. :type state: Union[str, bytes, bytearray, None] :keyword Optional[float] timeout: The total operation timeout in seconds including all the retries. The value must be greater than 0 if specified. The default value is None, meaning no timeout. diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_transport/__init__.py b/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_transport/__init__.py new file mode 100644 index 0000000000000..34913fb394d7a --- /dev/null +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_transport/__init__.py @@ -0,0 +1,4 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. 
+# -------------------------------------------------------------------------------------------- diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_transport/_base_async.py b/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_transport/_base_async.py new file mode 100644 index 0000000000000..7463244e7ad8d --- /dev/null +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_transport/_base_async.py @@ -0,0 +1,295 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- +from __future__ import annotations +from abc import ABC, abstractmethod +from typing import Tuple, Union, TYPE_CHECKING, Any, Dict, Callable +from typing_extensions import Literal + +if TYPE_CHECKING: + try: + from uamqp import types as uamqp_types + except ImportError: + uamqp_types = None + +class AmqpTransportAsync(ABC): # pylint: disable=too-many-public-methods + """ + Abstract class that defines a set of common methods needed by sender and receiver. 
+ """ + KIND: str + + # define constants + MAX_FRAME_SIZE_BYTES: int + MAX_MESSAGE_LENGTH_BYTES: int + TIMEOUT_FACTOR: int + CONNECTION_CLOSING_STATES: Tuple + + ServiceBusToAMQPReceiveModeMap: Dict[str, Any] + + # define symbols + PRODUCT_SYMBOL: Union[uamqp_types.AMQPSymbol, Literal["product"]] + VERSION_SYMBOL: Union[uamqp_types.AMQPSymbol, Literal["version"]] + FRAMEWORK_SYMBOL: Union[uamqp_types.AMQPSymbol, Literal["framework"]] + PLATFORM_SYMBOL: Union[uamqp_types.AMQPSymbol, Literal["platform"]] + USER_AGENT_SYMBOL: Union[uamqp_types.AMQPSymbol, Literal["user-agent"]] + PROP_PARTITION_KEY_AMQP_SYMBOL: Union[uamqp_types.AMQPSymbol, Literal[b'x-opt-partition-key']] + AMQP_LONG_VALUE: Callable + AMQP_ARRAY_VALUE: Callable + AMQP_UINT_VALUE: Callable + + @staticmethod + @abstractmethod + def build_message(**kwargs): + """ + Creates a uamqp.Message or pyamqp.Message with given arguments. + :rtype: uamqp.Message or pyamqp.Message + """ + + @staticmethod + @abstractmethod + def build_batch_message(data): + """ + Creates a uamqp.BatchMessage or pyamqp.BatchMessage with given arguments. + :rtype: uamqp.BatchMessage or pyamqp.BatchMessage + """ + + @staticmethod + @abstractmethod + def to_outgoing_amqp_message(annotated_message): + """ + Converts an AmqpAnnotatedMessage into an Amqp Message. + :param AmqpAnnotatedMessage annotated_message: AmqpAnnotatedMessage to convert. + :rtype: uamqp.Message or pyamqp.Message + """ + + @staticmethod + @abstractmethod + def update_message_app_properties(message, key, value): + """ + Adds the given key/value to the application properties of the message. + :param uamqp.Message or pyamqp.Message message: Message. + :param str key: Key to set in application properties. + :param str Value: Value to set for key in application properties. + :rtype: uamqp.Message or pyamqp.Message + """ + + @staticmethod + @abstractmethod + def get_batch_message_encoded_size(message): + """ + Gets the batch message encoded size given an underlying Message. 
+ :param uamqp.BatchMessage message: Message to get encoded size of. + :rtype: int + """ + + @staticmethod + @abstractmethod + def get_remote_max_message_size(handler): + """ + Returns max peer message size. + :param AMQPClient handler: Client to get remote max message size on link from. + :rtype: int + """ + + @staticmethod + @abstractmethod + def create_retry_policy(config): + """ + Creates the error retry policy. + :param Configuration config: Configuration. + """ + + @staticmethod + @abstractmethod + async def create_connection_async(host, auth, network_trace, **kwargs): + """ + Creates and returns the pyamqp Connection object. + :param str host: The hostname used by pyamqp. + :param JWTTokenAuth auth: The auth used by pyamqp. + :param bool network_trace: Debug setting. + """ + + @staticmethod + @abstractmethod + async def close_connection_async(connection): + """ + Closes existing connection. + :param connection: uamqp or pyamqp Connection. + """ + + @staticmethod + @abstractmethod + def create_send_client_async(config, **kwargs): + """ + Creates and returns the send client. + :param Configuration config: The configuration. + + :keyword str target: Required. The target. + :keyword JWTTokenAuth auth: Required. + :keyword int idle_timeout: Required. + :keyword network_trace: Required. + :keyword retry_policy: Required. + :keyword keep_alive_interval: Required. + :keyword str client_name: Required. + :keyword dict link_properties: Required. + :keyword properties: Required. + """ + + @staticmethod + @abstractmethod + async def send_messages_async(sender, message, logger, timeout, last_exception): + """ + Handles sending of service bus messages. + :param sender: The sender with handler to send messages. + :param int timeout: Timeout time. + :param last_exception: Exception to raise if message timed out. Only used by uamqp transport. + :param logger: Logger. 
+ """ + + @staticmethod + @abstractmethod + def create_source(source, session_filter): + """ + Creates and returns the Source. + + :param Source source: Required. + :param str or None session_id: Required. + """ + + @staticmethod + @abstractmethod + def create_receive_client_async(receiver, **kwargs): + """ + Creates and returns the receive client. + :param Configuration config: The configuration. + + :keyword Source source: Required. The source. + :keyword JWTTokenAuth auth: Required. + :keyword int idle_timeout: Required. + :keyword network_trace: Required. + :keyword retry_policy: Required. + :keyword str client_name: Required. + :keyword dict link_properties: Required. + :keyword properties: Required. + :keyword link_credit: Required. The prefetch. + :keyword keep_alive_interval: Required. Missing in pyamqp. + :keyword desired_capabilities: Required. + :keyword streaming_receive: Required. + :keyword message_received_callback: Required. + :keyword timeout: Required. + """ + + @staticmethod + @abstractmethod + async def iter_contextual_wrapper_async( + receiver, max_wait_time=None + ): + """The purpose of this wrapper is to allow both state restoration (for multiple concurrent iteration) + and per-iter argument passing that requires the former.""" + + @staticmethod + @abstractmethod + async def iter_next_async( + receiver, wait_time=None + ): + """ + Used to iterate through received messages. + """ + + @staticmethod + @abstractmethod + def build_received_message(receiver, message_type, received): + """ + Build ServiceBusReceivedMessage. + """ + + @staticmethod + @abstractmethod + def set_handler_message_received_async(receiver): + """ + Sets _message_received on async handler. + """ + + @staticmethod + @abstractmethod + def get_current_time(handler): + """ + Gets the current time. + """ + + @staticmethod + @abstractmethod + async def reset_link_credit_async( + handler, link_credit + ): + """ + Resets the link credit on the link. 
+ :param ReceiveClientAsync handler: Client with link to reset link credit. + :param int link_credit: Link credit needed. + :rtype: None + """ + + @staticmethod + @abstractmethod + async def settle_message_via_receiver_link_async( + handler, + message, + settle_operation, + dead_letter_reason=None, + dead_letter_error_description=None, + ) -> None: + """ + Settles message. + """ + + @staticmethod + @abstractmethod + def parse_received_message(message, message_type, **kwargs): + """ + Parses peek/deferred op messages into ServiceBusReceivedMessage. + :param Message message: Message to parse. + :param ServiceBusReceivedMessage message_type: Parse messages to return. + :keyword ServiceBusReceiver receiver: Required. + :keyword bool is_peeked_message: Optional. For peeked messages. + :keyword bool is_deferred_message: Optional. For deferred messages. + :keyword ServiceBusReceiveMode receive_mode: Optional. + """ + + @staticmethod + @abstractmethod + async def create_token_auth_async(auth_uri, get_token, token_type, config, **kwargs): + """ + Creates the JWTTokenAuth. + :param str auth_uri: The auth uri to pass to JWTTokenAuth. + :param get_token: The callback function used for getting and refreshing + tokens. It should return a valid jwt token each time it is called. + :param bytes token_type: Token type. + :param Configuration config: EH config. + + :keyword bool update_token: Whether to update token. If not updating token, + then pass 300 to refresh_window. Only used by uamqp. + """ + + @staticmethod + @abstractmethod + async def mgmt_client_request_async( + mgmt_client, + mgmt_msg, + *, + operation, + operation_type, + node, + timeout, + callback + ): + """ + Send mgmt request. + :param AMQPClient mgmt_client: Client to send request with. + :param Message mgmt_msg: Message. + :keyword bytes operation: Operation. + :keyword bytes operation_type: Op type. + :keyword bytes node: Mgmt target. + :keyword int timeout: Timeout. 
+ :keyword Callable callback: Callback to process request response. + """ diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_transport/_pyamqp_transport_async.py b/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_transport/_pyamqp_transport_async.py new file mode 100644 index 0000000000000..eec305cb1667f --- /dev/null +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_transport/_pyamqp_transport_async.py @@ -0,0 +1,384 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +from __future__ import annotations +import functools +from typing import TYPE_CHECKING, Optional, Any, Callable, Union, AsyncIterator, cast +import time + +from ..._pyamqp import constants +from ..._pyamqp.message import BatchMessage +from ..._pyamqp.utils import amqp_string_value +from ..._pyamqp.aio import SendClientAsync, ReceiveClientAsync +from ..._pyamqp.aio._authentication_async import JWTTokenAuthAsync +from ..._pyamqp.aio._connection_async import Connection as ConnectionAsync +from ..._pyamqp.error import ( + AMQPError, + MessageException, +) + +from ._base_async import AmqpTransportAsync +from ..._common.utils import utc_from_timestamp, utc_now +from ..._common.tracing import get_receive_links, receive_trace_context_manager +from ..._common.constants import ( + DATETIMEOFFSET_EPOCH, + SESSION_LOCKED_UNTIL, + SESSION_FILTER, + RECEIVER_LINK_DEAD_LETTER_ERROR_DESCRIPTION, + RECEIVER_LINK_DEAD_LETTER_REASON, + DEADLETTERNAME, + MESSAGE_COMPLETE, + MESSAGE_ABANDON, + MESSAGE_DEFER, + MESSAGE_DEAD_LETTER, + ServiceBusReceiveMode, +) +from ..._transport._pyamqp_transport import PyamqpTransport +from ...exceptions import ( + OperationTimeoutError +) + +if TYPE_CHECKING: 
+ from logging import Logger + from ..._common.message import ServiceBusReceivedMessage, ServiceBusMessage, ServiceBusMessageBatch + from ..._common._configuration import Configuration + from .._servicebus_receiver_async import ServiceBusReceiver + from .._servicebus_sender_async import ServiceBusSender + from ..._pyamqp.performatives import AttachFrame + from ..._pyamqp.message import Message + from ..._pyamqp.aio._client_async import AMQPClientAsync + +class PyamqpTransportAsync(PyamqpTransport, AmqpTransportAsync): + """ + Class which defines pyamqp-based methods used by the sender and receiver. + """ + + @staticmethod + async def create_connection_async( + host: str, auth: "JWTTokenAuthAsync", network_trace: bool, **kwargs: Any + ) -> "ConnectionAsync": + """ + Creates and returns the pyamqp Connection object. + :param str host: The hostname used by pyamqp. + :param JWTTokenAuth auth: The auth used by pyamqp. + :param bool network_trace: Debug setting. + """ + return ConnectionAsync( + endpoint=host, + sasl_credential=auth.sasl, + network_trace=network_trace, + **kwargs + ) + + @staticmethod + async def close_connection_async(connection: "ConnectionAsync") -> None: + """ + Closes existing connection. + :param connection: pyamqp Connection. + """ + await connection.close() + + @staticmethod + def create_send_client_async( + config: "Configuration", **kwargs: Any + ) -> "SendClientAsync": + """ + Creates and returns the pyamqp SendClient. + :param Configuration config: The configuration. + + :keyword str target: Required. The target. + :keyword JWTTokenAuth auth: Required. + :keyword int idle_timeout: Required. + :keyword network_trace: Required. + :keyword retry_policy: Required. + :keyword keep_alive_interval: Required. + :keyword str client_name: Required. + :keyword dict link_properties: Required. + :keyword properties: Required. 
+ """ + target = kwargs.pop("target") + return SendClientAsync( + config.hostname, + target, + network_trace=config.logging_enable, + keep_alive_interval=config.keep_alive, + custom_endpoint_address=config.custom_endpoint_address, + connection_verify=config.connection_verify, + transport_type=config.transport_type, + http_proxy=config.http_proxy, + **kwargs, + ) + + @staticmethod + async def send_messages_async( + sender: "ServiceBusSender", + message: Union["ServiceBusMessage", "ServiceBusMessageBatch"], + logger: "Logger", + timeout: int, + last_exception: Optional[Exception] + ) -> None: # pylint: disable=unused-argument + """ + Handles sending of service bus messages. + :param sender: The sender with handler to send messages. + :param int timeout: Timeout time. + :param last_exception: Exception to raise if message timed out. Only used by uamqp transport. + :param logger: Logger. + """ + # pylint: disable=protected-access + await sender._open() + try: + if isinstance(message._message, list): + await sender._handler.send_message_async(BatchMessage(*message._message), timeout=timeout) + else: + await sender._handler.send_message_async( + message._message, + timeout=timeout + ) + except TimeoutError: + raise OperationTimeoutError(message="Send operation timed out") + except MessageException as e: + raise PyamqpTransportAsync.create_servicebus_exception(logger, e) + + @staticmethod + def create_receive_client_async( + receiver: "ServiceBusReceiver", **kwargs: Any + ) -> "ReceiveClientAsync": # pylint:disable=unused-argument + """ + Creates and returns the receive client. + :param Configuration config: The configuration. + + :keyword str source: Required. The source. + :keyword str offset: Required. + :keyword str offset_inclusive: Required. + :keyword JWTTokenAuth auth: Required. + :keyword int idle_timeout: Required. + :keyword network_trace: Required. + :keyword retry_policy: Required. + :keyword str client_name: Required. 
+ :keyword dict link_properties: Required. + :keyword properties: Required. + :keyword link_credit: Required. The prefetch. + :keyword keep_alive_interval: Required. + :keyword desired_capabilities: Required. + :keyword streaming_receive: Required. + :keyword timeout: Required. + """ + config = receiver._config # pylint: disable=protected-access + source = kwargs.pop("source") + receive_mode = kwargs.pop("receive_mode") + + return ReceiveClientAsync( + config.hostname, + source, + http_proxy=config.http_proxy, + transport_type=config.transport_type, + custom_endpoint_address=config.custom_endpoint_address, + connection_verify=config.connection_verify, + receive_settle_mode=PyamqpTransportAsync.ServiceBusToAMQPReceiveModeMap[receive_mode], + send_settle_mode=constants.SenderSettleMode.Settled + if receive_mode == ServiceBusReceiveMode.RECEIVE_AND_DELETE + else constants.SenderSettleMode.Unsettled, + on_attach=functools.partial( + PyamqpTransportAsync.on_attach_async, + receiver + ), + **kwargs, + ) + + @staticmethod + async def iter_contextual_wrapper_async( + receiver: "ServiceBusReceiver", max_wait_time: Optional[int] = None + ) -> AsyncIterator["ServiceBusReceivedMessage"]: + while True: + try: + # pylint: disable=protected-access + message = await receiver._inner_anext(wait_time=max_wait_time) + links = get_receive_links(message) + with receive_trace_context_manager(receiver, links=links): + yield message + except StopAsyncIteration: + break + + @staticmethod + async def iter_next_async( + receiver: "ServiceBusReceiver", wait_time: Optional[int] = None + ) -> "ServiceBusReceivedMessage": + # pylint: disable=protected-access + try: + receiver._receive_context.set() + await receiver._open() + if not receiver._message_iter or wait_time: + receiver._message_iter = await receiver._handler.receive_messages_iter_async(timeout=wait_time) + pyamqp_message = await cast(AsyncIterator["Message"], receiver._message_iter).__anext__() + message = 
receiver._build_received_message(pyamqp_message) + if ( + receiver._auto_lock_renewer + and not receiver._session + and receiver._receive_mode != ServiceBusReceiveMode.RECEIVE_AND_DELETE + ): + receiver._auto_lock_renewer.register(receiver, message) + return message + finally: + receiver._receive_context.clear() + + @staticmethod + async def enhanced_message_received_async( + receiver: "ServiceBusReceiver", + frame: "AttachFrame", + message: "Message" + ) -> None: + # pylint: disable=protected-access + receiver._handler._last_activity_timestamp = time.time() + if receiver._receive_context.is_set(): + receiver._handler._received_messages.put((frame, message)) + else: + await receiver._handler.settle_messages_async(frame[1], 'released') + + @staticmethod + def set_handler_message_received_async(receiver: "ServiceBusReceiver") -> None: + # reassigning default _message_received method in ReceiveClient + # pylint: disable=protected-access + receiver._handler._message_received_async = functools.partial( # type: ignore[assignment] + PyamqpTransportAsync.enhanced_message_received_async, + receiver + ) + + @staticmethod + async def reset_link_credit_async( + handler: "ReceiveClientAsync", link_credit: int + ) -> None: + """ + Resets the link credit on the link. + :param ReceiveClientAsync handler: Client with link to reset link credit. + :param int link_credit: Link credit needed. 
+ :rtype: None + """ + await handler._link.flow(link_credit=link_credit) # pylint: disable=protected-access + + @staticmethod + async def settle_message_via_receiver_link_async( + handler: "ReceiveClientAsync", + message: "ServiceBusReceivedMessage", + settle_operation: str, + dead_letter_reason: Optional[str] = None, + dead_letter_error_description: Optional[str] = None, + ) -> None: + # pylint: disable=protected-access + if settle_operation == MESSAGE_COMPLETE: + return await handler.settle_messages_async(message._delivery_id, 'accepted') + if settle_operation == MESSAGE_ABANDON: + return await handler.settle_messages_async( + message._delivery_id, + 'modified', + delivery_failed=True, + undeliverable_here=False + ) + if settle_operation == MESSAGE_DEAD_LETTER: + return await handler.settle_messages_async( + message._delivery_id, + 'rejected', + error=AMQPError( + condition=DEADLETTERNAME, + description=dead_letter_error_description, + info={ + RECEIVER_LINK_DEAD_LETTER_REASON: dead_letter_reason, + RECEIVER_LINK_DEAD_LETTER_ERROR_DESCRIPTION: dead_letter_error_description, + } + ) + ) + if settle_operation == MESSAGE_DEFER: + return await handler.settle_messages_async( + message._delivery_id, + 'modified', + delivery_failed=True, + undeliverable_here=True + ) + raise ValueError( + f"Unsupported settle operation type: {settle_operation}" + ) + + @staticmethod + async def on_attach_async( + receiver: "ServiceBusReceiver", attach_frame: "AttachFrame" + ) -> None: + # pylint: disable=protected-access, unused-argument + if receiver._session and attach_frame.source.address.decode() == receiver._entity_uri: + # This has to live on the session object so that autorenew has access to it. 
+ receiver._session._session_start = utc_now() + expiry_in_seconds = attach_frame.properties.get(SESSION_LOCKED_UNTIL) + if expiry_in_seconds: + expiry_in_seconds = ( + expiry_in_seconds - DATETIMEOFFSET_EPOCH + ) / 10000000 + receiver._session._locked_until_utc = utc_from_timestamp(expiry_in_seconds) + session_filter = attach_frame.source.filters[SESSION_FILTER] + receiver._session_id = session_filter.decode(receiver._config.encoding) + receiver._session._session_id = receiver._session_id + + @staticmethod + async def create_token_auth_async( + auth_uri: str, + get_token: Callable, + token_type: bytes, + config: "Configuration", + **kwargs: Any + ) -> "JWTTokenAuthAsync": + """ + Creates the JWTTokenAuth. + :param str auth_uri: The auth uri to pass to JWTTokenAuth. + :param get_token: The callback function used for getting and refreshing + tokens. It should return a valid jwt token each time it is called. + :param bytes token_type: Token type. + :param Configuration config: EH config. + + :keyword bool update_token: Required. Whether to update token. If not updating token, + then pass 300 to refresh_window. 
+ """ + # TODO: figure out why we're passing all these args to pyamqp JWTTokenAuth, which aren't being used + update_token = kwargs.pop("update_token") # pylint: disable=unused-variable + if update_token: + # update_token not actually needed by pyamqp + # just using to detect which args to pass + return JWTTokenAuthAsync(auth_uri, auth_uri, get_token) + return JWTTokenAuthAsync( + auth_uri, + auth_uri, + get_token, + token_type=token_type, + timeout=config.auth_timeout, + custom_endpoint_hostname=config.custom_endpoint_hostname, + port=config.connection_port, + verify=config.connection_verify, + ) + + @staticmethod + async def mgmt_client_request_async( + mgmt_client: "AMQPClientAsync", + mgmt_msg: "Message", + *, + operation: bytes, + operation_type: bytes, + node: bytes, + timeout: int, + callback: Callable + ) -> "ServiceBusReceivedMessage": + """ + Send mgmt request. + :param AMQPClient mgmt_client: Client to send request with. + :param Message mgmt_msg: Message. + :keyword bytes operation: Operation. + :keyword bytes operation_type: Op type. + :keyword bytes node: Mgmt target. + :keyword int timeout: Timeout. + :keyword Callable callback: Callback to process request response. 
+ """ + status, description, response = await mgmt_client.mgmt_request_async( + mgmt_msg, + operation=amqp_string_value(operation.decode("UTF-8")), + operation_type=amqp_string_value(operation_type), + node=node, + timeout=timeout, # TODO: check if this should be seconds * 1000 if timeout else None, + ) + return callback(status, response, description, amqp_transport=PyamqpTransportAsync) diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_transport/_uamqp_transport_async.py b/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_transport/_uamqp_transport_async.py new file mode 100644 index 0000000000000..b11babed9f857 --- /dev/null +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_transport/_uamqp_transport_async.py @@ -0,0 +1,332 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. 
+# -------------------------------------------------------------------------------------------- + +from __future__ import annotations +import functools +from typing import TYPE_CHECKING, Optional, Any, Callable, Union, AsyncIterator, cast + +try: + from uamqp import ( + constants, + SendClientAsync, + ReceiveClientAsync, + ) + from uamqp.authentication import JWTTokenAsync as JWTTokenAuthAsync + from uamqp.async_ops import ConnectionAsync + from ..._transport._uamqp_transport import UamqpTransport + from ._base_async import AmqpTransportAsync + from .._async_utils import get_running_loop + from ..._common.tracing import get_receive_links, receive_trace_context_manager + from ..._common.constants import ServiceBusReceiveMode + + if TYPE_CHECKING: + from uamqp import AMQPClientAsync, Message + from logging import Logger + from .._servicebus_receiver_async import ServiceBusReceiver + from .._servicebus_sender_async import ServiceBusSender + from ..._common.message import ServiceBusReceivedMessage, ServiceBusMessage, ServiceBusMessageBatch + from ..._common._configuration import Configuration + + + class UamqpTransportAsync(UamqpTransport, AmqpTransportAsync): + """ + Class which defines uamqp-based methods used by the sender and receiver. + """ + + @staticmethod + async def create_connection_async( + host: str, auth: "JWTTokenAuthAsync", network_trace: bool, **kwargs: Any + ) -> "ConnectionAsync": + """ + Creates and returns the uamqp Connection object. + :param str host: The hostname, used by uamqp. + :param JWTTokenAuth auth: The auth, used by uamqp. + :param bool network_trace: Required. 
+ """ + custom_endpoint_address = kwargs.pop("custom_endpoint_address") # pylint:disable=unused-variable + ssl_opts = kwargs.pop("ssl_opts") # pylint:disable=unused-variable + transport_type = kwargs.pop("transport_type") # pylint:disable=unused-variable + http_proxy = kwargs.pop("http_proxy") # pylint:disable=unused-variable + return ConnectionAsync( + hostname=host, + sasl=auth, + debug=network_trace, + ) + + @staticmethod + async def close_connection_async(connection: "ConnectionAsync") -> None: + """ + Closes existing connection. + :param connection: uamqp or pyamqp Connection. + """ + await connection.destroy_async() + + @staticmethod + def create_send_client_async( + config: "Configuration", **kwargs: Any + ) -> "SendClientAsync": + """ + Creates and returns the uamqp SendClient. + :param Configuration config: The configuration. + + :keyword str target: Required. The target. + :keyword JWTTokenAuth auth: Required. + :keyword int idle_timeout: Required. + :keyword network_trace: Required. + :keyword retry_policy: Required. + :keyword keep_alive_interval: Required. + :keyword str client_name: Required. + :keyword dict link_properties: Required. + :keyword properties: Required. + """ + target = kwargs.pop("target") + retry_policy = kwargs.pop("retry_policy") + + return SendClientAsync( + target, + debug=config.logging_enable, + error_policy=retry_policy, + keep_alive_interval=config.keep_alive, + encoding=config.encoding, + **kwargs + ) + + @staticmethod + async def send_messages_async( + sender: "ServiceBusSender", + message: Union["ServiceBusMessage", "ServiceBusMessageBatch"], + logger: "Logger", + timeout: int, + last_exception: Optional[Exception] + ) -> None: + """ + Handles sending of service bus messages. + :param sender: The sender with handler to send messages. + :param message: ServiceBusMessage with uamqp.Message to be sent. 
+ :paramtype message: ~azure.servicebus.ServiceBusMessage or ~azure.servicebus.ServiceBusMessageBatch + :param int timeout: Timeout time. + :param last_exception: Exception to raise if message timed out. Only used by uamqp transport. + :param logger: Logger. + """ + # pylint: disable=protected-access + await sender._open() + default_timeout = cast("SendClientAsync", sender._handler)._msg_timeout + try: + UamqpTransportAsync.set_msg_timeout(sender, logger, timeout, last_exception) + await cast("SendClientAsync", sender._handler).send_message_async(message._message) + finally: # reset the timeout of the handler back to the default value + UamqpTransportAsync.set_msg_timeout(sender, logger, default_timeout, None) + + @staticmethod + def create_receive_client_async( + receiver: "ServiceBusReceiver", **kwargs: Any + ) -> "ReceiveClientAsync": + """ + Creates and returns the receive client. + :param Configuration config: The configuration. + + :keyword str source: Required. The source. + :keyword str offset: Required. + :keyword str offset_inclusive: Required. + :keyword JWTTokenAuth auth: Required. + :keyword int idle_timeout: Required. + :keyword network_trace: Required. + :keyword retry_policy: Required. + :keyword str client_name: Required. + :keyword dict link_properties: Required. + :keyword properties: Required. + :keyword link_credit: Required. The prefetch. + :keyword keep_alive_interval: Required. + :keyword desired_capabilities: Required. + :keyword streaming_receive: Required. + :keyword timeout: Required.
+ """ + source = kwargs.pop("source") + retry_policy = kwargs.pop("retry_policy") + network_trace = kwargs.pop("network_trace") + link_credit = kwargs.pop("link_credit") + receive_mode = kwargs.pop("receive_mode") + + return ReceiveClientAsync( + source, + debug=network_trace, # pylint:disable=protected-access + error_policy=retry_policy, + prefetch=link_credit, + auto_complete=False, + receive_settle_mode=UamqpTransportAsync.ServiceBusToAMQPReceiveModeMap[receive_mode], + send_settle_mode=constants.SenderSettleMode.Settled + if receive_mode == ServiceBusReceiveMode.RECEIVE_AND_DELETE + else None, + on_attach=functools.partial( + UamqpTransportAsync.on_attach, + receiver + ), + **kwargs + ) + + @staticmethod + async def iter_contextual_wrapper_async( + receiver: "ServiceBusReceiver", max_wait_time: Optional[int] = None + ) -> AsyncIterator["ServiceBusReceivedMessage"]: + """The purpose of this wrapper is to allow both state restoration (for multiple concurrent iteration) + and per-iter argument passing that requires the former.""" + # pylint: disable=protected-access + original_timeout = None + while True: + # This is not threadsafe, but gives us a way to handle if someone passes + # different max_wait_times to different iterators and uses them in concert. + if max_wait_time: + original_timeout = receiver._handler._timeout + receiver._handler._timeout = max_wait_time * UamqpTransport.TIMEOUT_FACTOR + try: + message = await receiver._inner_anext() + links = get_receive_links(message) + with receive_trace_context_manager(receiver, links=links): + yield message + except StopAsyncIteration: + break + finally: + if original_timeout: + try: + receiver._handler._timeout = original_timeout + except AttributeError: # Handler may be disposed already. 
+ pass + + # wait_time used by pyamqp + @staticmethod + async def iter_next_async( + receiver: "ServiceBusReceiver", wait_time: Optional[int] = None + ) -> "ServiceBusReceivedMessage": # pylint: disable=unused-argument + # pylint: disable=protected-access + try: + receiver._receive_context.set() + await receiver._open() + if not receiver._message_iter: + receiver._message_iter = receiver._handler.receive_messages_iter_async() + uamqp_message = await cast(AsyncIterator["Message"], receiver._message_iter).__anext__() + message = receiver._build_received_message(uamqp_message) + if ( + receiver._auto_lock_renewer + and not receiver._session + and receiver._receive_mode != ServiceBusReceiveMode.RECEIVE_AND_DELETE + ): + receiver._auto_lock_renewer.register(receiver, message) + return message + finally: + receiver._receive_context.clear() + + # called by async ServiceBusReceiver + enhanced_message_received_async = UamqpTransport.enhanced_message_received + + @staticmethod + def set_handler_message_received_async(receiver: "ServiceBusReceiver") -> None: + # reassigning default _message_received method in ReceiveClient + # pylint: disable=protected-access + receiver._handler._message_received = functools.partial( # type: ignore[assignment] + UamqpTransportAsync.enhanced_message_received_async, + receiver + ) + + @staticmethod + async def reset_link_credit_async( + handler: "ReceiveClientAsync", link_credit: int + ) -> None: + """ + Resets the link credit on the link. + :param ReceiveClientAsync handler: Client with link to reset link credit. + :param int link_credit: Link credit needed. 
+ :rtype: None + """ + await handler.message_handler.reset_link_credit_async(link_credit) + + @staticmethod + async def settle_message_via_receiver_link_async( + handler: "ReceiveClientAsync", + message: "ServiceBusReceivedMessage", + settle_operation: str, + dead_letter_reason: Optional[str] = None, + dead_letter_error_description: Optional[str] = None, + ) -> None: # pylint: disable=unused-argument + await get_running_loop().run_in_executor( + None, + UamqpTransportAsync.settle_message_via_receiver_link_impl( + handler, + message, + settle_operation, + dead_letter_reason, + dead_letter_error_description + ), + ) + + @staticmethod + async def create_token_auth_async( + auth_uri: str, + get_token: Callable, + token_type: bytes, + config: "Configuration", + **kwargs: Any + ) -> "JWTTokenAuthAsync": + """ + Creates the JWTTokenAuth. + :param str auth_uri: The auth uri to pass to JWTTokenAuth. + :param get_token: The callback function used for getting and refreshing + tokens. It should return a valid jwt token each time it is called. + :param bytes token_type: Token type. + :param Configuration config: EH config. + + :keyword bool update_token: Required. Whether to update token. If not updating token, + then pass 300 to refresh_window. 
+ """ + update_token = kwargs.pop("update_token") + refresh_window = 0 if update_token else 300 + + token_auth = JWTTokenAuthAsync( + auth_uri, + auth_uri, + get_token, + token_type=token_type, + timeout=config.auth_timeout, + http_proxy=config.http_proxy, + transport_type=config.transport_type, + custom_endpoint_hostname=config.custom_endpoint_hostname, + port=config.connection_port, + verify=config.connection_verify, + refresh_window=refresh_window + ) + if update_token: + await token_auth.update_token() + return token_auth + + @staticmethod + async def mgmt_client_request_async( + mgmt_client: "AMQPClientAsync", + mgmt_msg: "Message", + *, + operation: bytes, + operation_type: bytes, + node: bytes, + timeout: int, + callback: Callable + ) -> "ServiceBusReceivedMessage": + """ + Send mgmt request. + :param AMQPClient mgmt_client: Client to send request with. + :param Message mgmt_msg: Message. + :keyword bytes operation: Operation. + :keyword bytes operation_type: Op type. + :keyword bytes node: Mgmt target. + :keyword int timeout: Timeout. + :keyword Callable callback: Callback to process request response. 
+ """ + return await mgmt_client.mgmt_request_async( + mgmt_msg, + operation, + op_type=operation_type, + node=node, + timeout=timeout * UamqpTransportAsync.TIMEOUT_FACTOR if timeout else None, + callback=functools.partial(callback, amqp_transport=UamqpTransportAsync) + ) +except ImportError: + pass diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/aio/management/_management_client_async.py b/sdk/servicebus/azure-servicebus/azure/servicebus/aio/management/_management_client_async.py index c16816d9e8d11..952735619f17c 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/aio/management/_management_client_async.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/aio/management/_management_client_async.py @@ -225,7 +225,7 @@ async def _create_forward_to_header_tokens(self, entity, kwargs): kwargs["headers"] = kwargs.get("headers", {}) async def _populate_header_within_kwargs(uri, header): - token = (await self._credential.get_token(uri)).token.decode() + token = (await self._credential.get_token(uri)).token if not isinstance( self._credential, (ServiceBusSASTokenCredential, ServiceBusSharedKeyCredential), @@ -565,7 +565,7 @@ async def get_topic(self, topic_name: str, **kwargs: Any) -> TopicProperties: return topic_description async def get_topic_runtime_properties( - self, topic_name: str, **kwargs + self, topic_name: str, **kwargs: Any ) -> TopicRuntimeProperties: """Get the runtime information of a topic. 
@@ -810,8 +810,7 @@ async def get_subscription( ) ) subscription = SubscriptionProperties._from_internal_entity( - subscription_name, - entry.content.subscription_description + subscription_name, entry.content.subscription_description ) return subscription @@ -835,8 +834,7 @@ async def get_subscription_runtime_properties( ) ) subscription = SubscriptionRuntimeProperties._from_internal_entity( - subscription_name, - entry.content.subscription_description + subscription_name, entry.content.subscription_description ) return subscription @@ -957,8 +955,7 @@ async def create_subscription( # since we know for certain that `entry.content` will not be None here. entry.content = cast(SubscriptionDescriptionEntryContent, entry.content) result = SubscriptionProperties._from_internal_entity( - subscription_name, - entry.content.subscription_description + subscription_name, entry.content.subscription_description ) return result @@ -1269,10 +1266,11 @@ async def get_namespace_properties(self, **kwargs: Any) -> NamespaceProperties: """ entry_el = await self._impl.namespace.get(**kwargs) namespace_entry = NamespacePropertiesEntry.deserialize(entry_el) - namespace_entry.content = cast(NamespacePropertiesEntryContent, namespace_entry.content) + namespace_entry.content = cast( + NamespacePropertiesEntryContent, namespace_entry.content + ) return NamespaceProperties._from_internal_entity( - namespace_entry.title, - namespace_entry.content.namespace_properties + namespace_entry.title, namespace_entry.content.namespace_properties ) async def close(self) -> None: diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/aio/management/_shared_key_policy_async.py b/sdk/servicebus/azure-servicebus/azure/servicebus/aio/management/_shared_key_policy_async.py index a457d5c7d6fd4..a948bb143440c 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/aio/management/_shared_key_policy_async.py +++ 
b/sdk/servicebus/azure-servicebus/azure/servicebus/aio/management/_shared_key_policy_async.py @@ -34,7 +34,7 @@ async def _update_token(self): # pylint: disable=invalid-overridden-method access_token, self._token_expiry_on = await self._credential.get_token( self._endpoint ) - self._token = access_token.decode("utf-8") + self._token = access_token async def on_request( self, request: PipelineRequest diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/amqp/_amqp_message.py b/sdk/servicebus/azure-servicebus/azure/servicebus/amqp/_amqp_message.py index 6e3293e3921f7..d4de6efadd172 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/amqp/_amqp_message.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/amqp/_amqp_message.py @@ -5,18 +5,14 @@ # ------------------------------------------------------------------------- from __future__ import annotations -import time -import uuid -from datetime import datetime import warnings -from typing import Optional, Any, Tuple, cast, Mapping, Union, Dict, List +from typing import Optional, Any, cast, Mapping, Dict, Union, List, Iterable, Tuple, TYPE_CHECKING -from datetime import timezone -import uamqp - -from ._constants import AMQP_MESSAGE_BODY_TYPE_MAP, AmqpMessageBodyType -from .._common.constants import MAX_DURATION_VALUE, MAX_ABSOLUTE_EXPIRY_TIME +from ._amqp_utils import normalized_data_body, normalized_sequence_body +from ._constants import AmqpMessageBodyType +if TYPE_CHECKING: + import uuid class DictMixin(object): def __setitem__(self, key: str, item: Any) -> None: @@ -75,7 +71,6 @@ class AmqpAnnotatedMessage(object): access to low-level AMQP message sections. There should be one and only one of either data_body, sequence_body or value_body being set as the body of the AmqpAnnotatedMessage; if more than one body is set, `ValueError` will be raised. 
- Please refer to the AMQP spec: http://docs.oasis-open.org/amqp/core/v1.0/os/amqp-core-messaging-v1.0-os.html#section-message-format for more information on the message format. @@ -112,12 +107,15 @@ def __init__( delivery_annotations: Optional[Dict[str, Any]] = None, **kwargs: Any ) -> None: - self._message = kwargs.pop("message", None) self._encoding = kwargs.pop("encoding", "UTF-8") + self._data_body: Optional[Union[str, bytes, List[Union[str, bytes]]]] = None + self._sequence_body: Optional[List[Any]] = None + self._value_body: Any = None # internal usage only for service bus received message - if self._message: - self._from_amqp_message(self._message) + message = kwargs.pop("message", None) + if message: + self._from_amqp_message(message) return # manually constructed AMQPAnnotatedMessage @@ -128,37 +126,41 @@ def __init__( "or value_body being set as the body of the AmqpAnnotatedMessage." ) - self._body = None - self._body_type = None + self._body_type: AmqpMessageBodyType = None # type: ignore if "data_body" in kwargs: - self._body = kwargs.get("data_body") - self._body_type = uamqp.MessageBodyType.Data + self._data_body = normalized_data_body(kwargs.get("data_body")) + self._body_type = AmqpMessageBodyType.DATA elif "sequence_body" in kwargs: - self._body = kwargs.get("sequence_body") - self._body_type = uamqp.MessageBodyType.Sequence + self._sequence_body = normalized_sequence_body(kwargs.get("sequence_body")) + self._body_type = AmqpMessageBodyType.SEQUENCE elif "value_body" in kwargs: - self._body = kwargs.get("value_body") - self._body_type = uamqp.MessageBodyType.Value + self._value_body = kwargs.get("value_body") + self._body_type = AmqpMessageBodyType.VALUE - self._message = uamqp.message.Message(body=self._body, body_type=self._body_type) header_dict = cast(Mapping, header) self._header = AmqpMessageHeader(**header_dict) if header else None self._footer = footer properties_dict = cast(Mapping, properties) self._properties = 
AmqpMessageProperties(**properties_dict) if properties else None - self._application_properties = application_properties - self._annotations = annotations - self._delivery_annotations = delivery_annotations + self._application_properties = cast(Optional[Dict[Union[str, bytes], Any]], application_properties) + self._annotations = cast(Optional[Dict[Union[str, bytes], Any]], annotations) + self._delivery_annotations = cast(Optional[Dict[Union[str, bytes], Any]], delivery_annotations) def __str__(self) -> str: - return str(self._message) + if self._body_type == AmqpMessageBodyType.DATA: # pylint:disable=no-else-return + return "".join(d.decode(self._encoding) for d in cast(Iterable[bytes], self._data_body)) + elif self._body_type == AmqpMessageBodyType.SEQUENCE: + return str(self._sequence_body) + elif self._body_type == AmqpMessageBodyType.VALUE: + return str(self._value_body) + return "" def __repr__(self) -> str: # pylint: disable=bare-except message_repr = "body={}".format( str(self) ) - message_repr += ", body_type={}".format(self.body_type) + message_repr += ", body_type={}".format(self._body_type.value) try: message_repr += ", header={}".format(self.header) except: @@ -186,7 +188,6 @@ def __repr__(self) -> str: return "AmqpAnnotatedMessage({})".format(message_repr)[:1024] def _from_amqp_message(self, message): - # populate the properties from an uamqp message self._properties = AmqpMessageProperties( message_id=message.properties.message_id, user_id=message.properties.user_id, @@ -204,126 +205,52 @@ def _from_amqp_message(self, message): ) if message.properties else None self._header = AmqpMessageHeader( delivery_count=message.header.delivery_count, - time_to_live=message.header.time_to_live, + time_to_live=message.header.ttl, first_acquirer=message.header.first_acquirer, durable=message.header.durable, priority=message.header.priority ) if message.header else None self._footer = message.footer - self._annotations = message.annotations + self._annotations = 
message.message_annotations self._delivery_annotations = message.delivery_annotations self._application_properties = message.application_properties - - def _to_outgoing_amqp_message(self): - message_header = None - ttl_set = False - if self.header: - message_header = uamqp.message.MessageHeader() - message_header.delivery_count = self.header.delivery_count - message_header.time_to_live = self.header.time_to_live - message_header.first_acquirer = self.header.first_acquirer - message_header.durable = self.header.durable - message_header.priority = self.header.priority - if self.header.time_to_live and self.header.time_to_live != MAX_DURATION_VALUE: - ttl_set = True - creation_time_from_ttl = int(time.mktime(datetime.now(timezone.utc).timetuple()) * 1000) - absolute_expiry_time_from_ttl = int(min( - MAX_ABSOLUTE_EXPIRY_TIME, - creation_time_from_ttl + self.header.time_to_live - )) - - message_properties = None - if self.properties: - creation_time = None - absolute_expiry_time = None - if ttl_set: - creation_time = creation_time_from_ttl - absolute_expiry_time = absolute_expiry_time_from_ttl - else: - if self.properties.creation_time: - creation_time = int(self.properties.creation_time) - if self.properties.absolute_expiry_time: - absolute_expiry_time = int(self.properties.absolute_expiry_time) - - message_properties = uamqp.message.MessageProperties( - message_id=self.properties.message_id, - user_id=self.properties.user_id, - to=self.properties.to, - subject=self.properties.subject, - reply_to=self.properties.reply_to, - correlation_id=self.properties.correlation_id, - content_type=self.properties.content_type, - content_encoding=self.properties.content_encoding, - creation_time=creation_time, - absolute_expiry_time=absolute_expiry_time, - group_id=self.properties.group_id, - group_sequence=self.properties.group_sequence, - reply_to_group_id=self.properties.reply_to_group_id, - encoding=self._encoding - ) - elif ttl_set: - message_properties = 
uamqp.message.MessageProperties( - creation_time=creation_time_from_ttl if ttl_set else None, - absolute_expiry_time=absolute_expiry_time_from_ttl if ttl_set else None, - ) - - amqp_body = self._message._body # pylint: disable=protected-access - if isinstance(amqp_body, uamqp.message.DataBody): - amqp_body_type = uamqp.MessageBodyType.Data - amqp_body = list(amqp_body.data) - elif isinstance(amqp_body, uamqp.message.SequenceBody): - amqp_body_type = uamqp.MessageBodyType.Sequence - amqp_body = list(amqp_body.data) + if message.data: + self._data_body = cast(Iterable, list(message.data)) + self._body_type = AmqpMessageBodyType.DATA + elif message.sequence: + self._sequence_body = cast(Iterable, list(message.sequence)) + self._body_type = AmqpMessageBodyType.SEQUENCE else: - # amqp_body is type of uamqp.message.ValueBody - amqp_body_type = uamqp.MessageBodyType.Value - amqp_body = amqp_body.data - - return uamqp.message.Message( - body=amqp_body, - body_type=amqp_body_type, - header=message_header, - properties=message_properties, - application_properties=self.application_properties, - annotations=self.annotations, - delivery_annotations=self.delivery_annotations, - footer=self.footer - ) - - def _to_outgoing_message(self, message_type): - # convert to an outgoing ServiceBusMessage - return message_type(body=None, message=self._to_outgoing_amqp_message(), raw_amqp_message=self) + self._value_body = message.value + self._body_type = AmqpMessageBodyType.VALUE @property def body(self) -> Any: """The body of the Message. The format may vary depending on the body type: - For :class:`azure.servicebus.amqp.AmqpMessageBodyType.DATA`, - the body could be bytes or Iterable[bytes]. - For - :class:`azure.servicebus.amqp.AmqpMessageBodyType.SEQUENCE`, - the body could be List or Iterable[List]. - For :class:`azure.servicebus.amqp.AmqpMessageBodyType.VALUE`, - the body could be any type. 
- + For ~azure.servicebus.AmqpMessageBodyType.DATA, the body could be bytes or Iterable[bytes] + For ~azure.servicebus.AmqpMessageBodyType.SEQUENCE, the body could be List or Iterable[List] + For ~azure.servicebus.AmqpMessageBodyType.VALUE, the body could be any type. :rtype: Any """ - return self._message.get_data() + if self._body_type == AmqpMessageBodyType.DATA: # pylint:disable=no-else-return + return (i for i in cast(Iterable, self._data_body)) + elif self._body_type == AmqpMessageBodyType.SEQUENCE: + return (i for i in cast(Iterable, self._sequence_body)) + elif self._body_type == AmqpMessageBodyType.VALUE: + return self._value_body + return None @property def body_type(self) -> AmqpMessageBodyType: """The body type of the underlying AMQP message. - - :rtype: ~azure.servicebus.amqp.AmqpMessageBodyType + rtype: ~azure.servicebus.amqp.AmqpMessageBodyType """ - return AMQP_MESSAGE_BODY_TYPE_MAP.get( - self._message._body.type, AmqpMessageBodyType.VALUE # pylint: disable=protected-access - ) + return self._body_type @property def properties(self) -> Optional[AmqpMessageProperties]: """ Properties to add to the message. - :rtype: Optional[~azure.servicebus.amqp.AmqpMessageProperties] """ return self._properties @@ -333,20 +260,20 @@ def properties(self, value: AmqpMessageProperties) -> None: self._properties = value @property - def application_properties(self) -> Optional[Dict]: + def application_properties(self) -> Optional[Dict[Union[str, bytes], Any]]: """ Service specific application properties. - :rtype: Optional[dict] + :rtype: Optional[Dict] """ return self._application_properties @application_properties.setter - def application_properties(self, value: Dict) -> None: + def application_properties(self, value: Optional[Dict[Union[str, bytes], Any]]) -> None: self._application_properties = value @property - def annotations(self) -> Optional[Dict]: + def annotations(self) -> Optional[Dict[Union[str, bytes], Any]]: """ Service specific message annotations. 
@@ -355,11 +282,11 @@ def annotations(self) -> Optional[Dict]: return self._annotations @annotations.setter - def annotations(self, value: Dict) -> None: + def annotations(self, value: Optional[Dict[Union[str, bytes], Any]]) -> None: self._annotations = value @property - def delivery_annotations(self) -> Optional[Dict]: + def delivery_annotations(self) -> Optional[Dict[Union[str, bytes], Any]]: """ Delivery-specific non-standard properties at the head of the message. Delivery annotations convey information from the sending peer to the receiving peer. @@ -369,14 +296,13 @@ def delivery_annotations(self) -> Optional[Dict]: return self._delivery_annotations @delivery_annotations.setter - def delivery_annotations(self, value: Dict) -> None: + def delivery_annotations(self, value: Optional[Dict[Union[str, bytes], Any]]) -> None: self._delivery_annotations = value @property def header(self) -> Optional[AmqpMessageHeader]: """ The message header. - :rtype: Optional[~azure.servicebus.amqp.AmqpMessageHeader] """ return self._header @@ -386,7 +312,7 @@ def header(self, value: AmqpMessageHeader) -> None: self._header = value @property - def footer(self) -> Optional[Dict]: + def footer(self) -> Optional[Dict[Any, Any]]: """ The message footer. @@ -395,7 +321,7 @@ def footer(self) -> Optional[Dict]: return self._footer @footer.setter - def footer(self, value: Dict) -> None: + def footer(self, value: Optional[Dict[Any, Any]]) -> None: self._footer = value @@ -404,7 +330,6 @@ class AmqpMessageHeader(DictMixin): The Message header. This is only used on received message, and not set on messages being sent. The properties set on any given message will depend on the Service and not all messages will have all properties. - Please refer to the AMQP spec: http://docs.oasis-open.org/amqp/core/v1.0/os/amqp-core-messaging-v1.0-os.html#type-header for more information on the message header. 
@@ -436,7 +361,6 @@ class AmqpMessageHeader(DictMixin): :keyword priority: This field contains the relative message priority. Higher numbers indicate higher priority messages. Messages with higher priorities MAY be delivered before those with lower priorities. :paramtype priority: Optional[int] - :ivar delivery_count: The number of unsuccessful previous attempts to deliver this message. If this value is non-zero it can be taken as an indication that the delivery might be a duplicate. On first delivery, the value is zero. It is @@ -490,13 +414,12 @@ class AmqpMessageProperties(DictMixin): The properties that are actually used will depend on the service implementation. Not all received messages will have all properties, and not all properties will be utilized on a sent message. - Please refer to the AMQP spec: http://docs.oasis-open.org/amqp/core/v1.0/os/amqp-core-messaging-v1.0-os.html#type-properties for more information on the message properties. :keyword message_id: Message-id, if set, uniquely identifies a message within the message system. - The message producer is usually responsible for setting the message-id in such a way that it + The message sender is usually responsible for setting the message-id in such a way that it is assured to be globally unique. A broker MAY discard a message as a duplicate if the value of the message-id matches that of a previously received message sent to the same node. :paramtype message_id: Optional[Union[str, bytes, uuid.UUID]] @@ -527,9 +450,8 @@ class AmqpMessageProperties(DictMixin): :keyword reply_to_group_id: This is a client-specific id that is used so that client can send replies to this message to a specific group. :paramtype reply_to_group_id: Optional[Union[str, bytes]] - :ivar message_id: Message-id, if set, uniquely identifies a message within the message system. 
- The message producer is usually responsible for setting the message-id in such a way that it + The message sender is usually responsible for setting the message-id in such a way that it is assured to be globally unique. A broker MAY discard a message as a duplicate if the value of the message-id matches that of a previously received message sent to the same node. :vartype message_id: Optional[Union[str, bytes, uuid.UUID]] @@ -564,7 +486,7 @@ class AmqpMessageProperties(DictMixin): def __init__( self, *, - message_id: Optional[Union[str, bytes, uuid.UUID]] = None, + message_id: Optional[Union[str, bytes, "uuid.UUID"]] = None, user_id: Optional[Union[str, bytes]] = None, to: Optional[Union[str, bytes]] = None, subject: Optional[Union[str, bytes]] = None, diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/amqp/_amqp_utils.py b/sdk/servicebus/azure-servicebus/azure/servicebus/amqp/_amqp_utils.py new file mode 100644 index 0000000000000..c620c149ea5e1 --- /dev/null +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/amqp/_amqp_utils.py @@ -0,0 +1,25 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. 
+# ------------------------------------------------------------------------- + +def encode_str(data, encoding='utf-8'): + try: + return data.encode(encoding) + except AttributeError: + return data + +def normalized_data_body(data, **kwargs): + # A helper method to normalize input into AMQP Data Body format + encoding = kwargs.get("encoding", "utf-8") + if isinstance(data, list): + return [encode_str(item, encoding) for item in data] + return [encode_str(data, encoding)] + +def normalized_sequence_body(sequence): + # A helper method to normalize input into AMQP Sequence Body format + if isinstance(sequence, list) and all([isinstance(b, list) for b in sequence]): + return sequence + if isinstance(sequence, list): + return [sequence] diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/amqp/_constants.py b/sdk/servicebus/azure-servicebus/azure/servicebus/amqp/_constants.py index 05ea858bcfc6e..01808178fb9a0 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/amqp/_constants.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/amqp/_constants.py @@ -4,18 +4,9 @@ # license information. 
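The helpers added in `_amqp_utils.py` above can be exercised standalone; this sketch copies their logic to show the normalization each performs:

```python
def encode_str(data, encoding="utf-8"):
    # str -> bytes; values without .encode (e.g. bytes) pass through unchanged
    try:
        return data.encode(encoding)
    except AttributeError:
        return data

def normalized_data_body(data, encoding="utf-8"):
    # An AMQP data body is a list of bytes sections
    if isinstance(data, list):
        return [encode_str(item, encoding) for item in data]
    return [encode_str(data, encoding)]

def normalized_sequence_body(sequence):
    # An AMQP sequence body is a list of lists
    if isinstance(sequence, list) and all(isinstance(b, list) for b in sequence):
        return sequence
    if isinstance(sequence, list):
        return [sequence]

print(normalized_data_body("hello"))      # [b'hello']
print(normalized_data_body([b"a", "b"]))  # [b'a', b'b']
print(normalized_sequence_body([1, 2]))   # [[1, 2]]
```

Note that `normalized_sequence_body` implicitly returns `None` for non-list input; callers in the diff only reach it via the `sequence_body` keyword, which is expected to be a list.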
# ------------------------------------------------------------------------- from enum import Enum - -from uamqp import MessageBodyType from azure.core import CaseInsensitiveEnumMeta class AmqpMessageBodyType(str, Enum, metaclass=CaseInsensitiveEnumMeta): DATA = "data" SEQUENCE = "sequence" VALUE = "value" - - -AMQP_MESSAGE_BODY_TYPE_MAP = { - MessageBodyType.Data.value: AmqpMessageBodyType.DATA, - MessageBodyType.Sequence.value: AmqpMessageBodyType.SEQUENCE, - MessageBodyType.Value.value: AmqpMessageBodyType.VALUE, -} diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/exceptions.py b/sdk/servicebus/azure-servicebus/azure/servicebus/exceptions.py index 4baaaa4c1766b..33af8995e8df1 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/exceptions.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/exceptions.py @@ -6,205 +6,8 @@ from typing import Any -from uamqp import errors as AMQPErrors, constants -from uamqp.constants import ErrorCodes as AMQPErrorCodes from azure.core.exceptions import AzureError -from ._common.constants import ( - ERROR_CODE_SESSION_LOCK_LOST, - ERROR_CODE_MESSAGE_LOCK_LOST, - ERROR_CODE_MESSAGE_NOT_FOUND, - ERROR_CODE_TIMEOUT, - ERROR_CODE_AUTH_FAILED, - ERROR_CODE_SESSION_CANNOT_BE_LOCKED, - ERROR_CODE_SERVER_BUSY, - ERROR_CODE_ARGUMENT_ERROR, - ERROR_CODE_OUT_OF_RANGE, - ERROR_CODE_ENTITY_DISABLED, - ERROR_CODE_ENTITY_ALREADY_EXISTS, - ERROR_CODE_PRECONDITION_FAILED, -) - - -_NO_RETRY_CONDITION_ERROR_CODES = ( - constants.ErrorCodes.DecodeError, - constants.ErrorCodes.LinkMessageSizeExceeded, - constants.ErrorCodes.NotFound, - constants.ErrorCodes.NotImplemented, - constants.ErrorCodes.LinkRedirect, - constants.ErrorCodes.NotAllowed, - constants.ErrorCodes.UnauthorizedAccess, - constants.ErrorCodes.LinkStolen, - constants.ErrorCodes.ResourceLimitExceeded, - constants.ErrorCodes.ConnectionRedirect, - constants.ErrorCodes.PreconditionFailed, - constants.ErrorCodes.InvalidField, - constants.ErrorCodes.ResourceDeleted, - 
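With `AMQP_MESSAGE_BODY_TYPE_MAP` deleted above, the uamqp `MessageBodyType` lookup is no longer needed because the body type is a plain string-valued Enum. A dependency-free sketch (the real class also uses azure.core's `CaseInsensitiveEnumMeta`, omitted here):

```python
from enum import Enum

# Standalone sketch of the body type enum after the uamqp mapping is removed.
class AmqpMessageBodyType(str, Enum):
    DATA = "data"
    SEQUENCE = "sequence"
    VALUE = "value"

# str-mixin Enum members compare equal to their underlying strings,
# so no MessageBodyType -> AmqpMessageBodyType translation table is needed.
print(AmqpMessageBodyType.DATA == "data")                        # True
print(AmqpMessageBodyType("value") is AmqpMessageBodyType.VALUE)  # True
```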
constants.ErrorCodes.IllegalState, - constants.ErrorCodes.FrameSizeTooSmall, - constants.ErrorCodes.ConnectionFramingError, - constants.ErrorCodes.SessionUnattachedHandle, - constants.ErrorCodes.SessionHandleInUse, - constants.ErrorCodes.SessionErrantLink, - constants.ErrorCodes.SessionWindowViolation, - ERROR_CODE_SESSION_LOCK_LOST, - ERROR_CODE_MESSAGE_LOCK_LOST, - ERROR_CODE_OUT_OF_RANGE, - ERROR_CODE_ARGUMENT_ERROR, - ERROR_CODE_PRECONDITION_FAILED, -) - - -def _error_handler(error): - """Handle connection and service errors. - - Called internally when an event has failed to send so we - can parse the error to determine whether we should attempt - to retry sending the event again. - Returns the action to take according to error type. - - :param error: The error received in the send attempt. - :type error: Exception - :rtype: ~uamqp.errors.ErrorAction - """ - if error.condition == b"com.microsoft:server-busy": - return AMQPErrors.ErrorAction(retry=True, backoff=4) - if error.condition == b"com.microsoft:timeout": - return AMQPErrors.ErrorAction(retry=True, backoff=2) - if error.condition == b"com.microsoft:operation-cancelled": - return AMQPErrors.ErrorAction(retry=True) - if error.condition == b"com.microsoft:container-close": - return AMQPErrors.ErrorAction(retry=True, backoff=4) - if error.condition in _NO_RETRY_CONDITION_ERROR_CODES: - return AMQPErrors.ErrorAction(retry=False) - return AMQPErrors.ErrorAction(retry=True) - - -def _handle_amqp_exception_with_condition( - logger, condition, description, exception=None, status_code=None -): - # - # handling AMQP Errors that have the condition field or the mgmt handler - logger.info( - "AMQP error occurred: (%r), condition: (%r), description: (%r).", - exception, - condition, - description, - ) - if condition == AMQPErrorCodes.NotFound: - # handle NotFound error code - error_cls = ( - ServiceBusCommunicationError - if isinstance(exception, AMQPErrors.AMQPConnectionError) - else MessagingEntityNotFoundError - ) - 
elif condition == AMQPErrorCodes.ClientError and "timed out" in str(exception): - # handle send timeout - error_cls = OperationTimeoutError - elif condition == AMQPErrorCodes.UnknownError and isinstance(exception, AMQPErrors.AMQPConnectionError): - error_cls = ServiceBusConnectionError - else: - # handle other error codes - error_cls = _ERROR_CODE_TO_ERROR_MAPPING.get(condition, ServiceBusError) - - error = error_cls( - message=description, - error=exception, - condition=condition, - status_code=status_code, - ) - if condition in _NO_RETRY_CONDITION_ERROR_CODES: - error._retryable = False # pylint: disable=protected-access - else: - error._retryable = True # pylint: disable=protected-access - - return error - - -def _handle_amqp_exception_without_condition(logger, exception): - error_cls = ServiceBusError - if isinstance(exception, AMQPErrors.AMQPConnectionError): - logger.info("AMQP Connection error occurred: (%r).", exception) - error_cls = ServiceBusConnectionError - elif isinstance(exception, AMQPErrors.AuthenticationException): - logger.info("AMQP Connection authentication error occurred: (%r).", exception) - error_cls = ServiceBusAuthenticationError - elif isinstance(exception, AMQPErrors.MessageException): - logger.info("AMQP Message error occurred: (%r).", exception) - if isinstance(exception, AMQPErrors.MessageAlreadySettled): - error_cls = MessageAlreadySettled - elif isinstance(exception, AMQPErrors.MessageContentTooLarge): - error_cls = MessageSizeExceededError - else: - logger.info( - "Unexpected AMQP error occurred (%r). 
Handler shutting down.", exception - ) - - error = error_cls(message=str(exception), error=exception) - return error - - -def _handle_amqp_mgmt_error( - logger, error_description, condition=None, description=None, status_code=None -): - if description: - error_description += " {}.".format(description) - - raise _handle_amqp_exception_with_condition( - logger, - condition, - description=error_description, - exception=None, - status_code=status_code, - ) - - -def _create_servicebus_exception(logger, exception): - if isinstance(exception, AMQPErrors.AMQPError): - try: - # handling AMQP Errors that have the condition field - condition = exception.condition - description = exception.description - exception = _handle_amqp_exception_with_condition( - logger, condition, description, exception=exception - ) - except AttributeError: - # handling AMQP Errors that don't have the condition field - exception = _handle_amqp_exception_without_condition(logger, exception) - elif not isinstance(exception, ServiceBusError): - logger.exception( - "Unexpected error occurred (%r). 
Handler shutting down.", exception - ) - exception = ServiceBusError( - message="Handler failed: {}.".format(exception), error=exception - ) - - return exception - - -class _ServiceBusErrorPolicy(AMQPErrors.ErrorPolicy): - def __init__(self, max_retries=3, is_session=False): - self._is_session = is_session - super(_ServiceBusErrorPolicy, self).__init__( - max_retries=max_retries, on_error=_error_handler - ) - - def on_unrecognized_error(self, error): - if self._is_session: - return AMQPErrors.ErrorAction(retry=False) - return super(_ServiceBusErrorPolicy, self).on_unrecognized_error(error) - - def on_link_error(self, error): - if self._is_session: - return AMQPErrors.ErrorAction(retry=False) - return super(_ServiceBusErrorPolicy, self).on_link_error(error) - - def on_connection_error(self, error): - if self._is_session: - return AMQPErrors.ErrorAction(retry=False) - return super(_ServiceBusErrorPolicy, self).on_connection_error(error) - - class ServiceBusError(AzureError): """Base exception for all Service Bus errors which can be used for default error handling. 
@@ -487,24 +290,3 @@ class AutoLockRenewFailed(ServiceBusError): class AutoLockRenewTimeout(ServiceBusError): """The time allocated to renew the message or session lock has elapsed.""" - - -_ERROR_CODE_TO_ERROR_MAPPING = { - AMQPErrorCodes.LinkMessageSizeExceeded: MessageSizeExceededError, - AMQPErrorCodes.ResourceLimitExceeded: ServiceBusQuotaExceededError, - AMQPErrorCodes.UnauthorizedAccess: ServiceBusAuthorizationError, - AMQPErrorCodes.NotImplemented: ServiceBusError, - AMQPErrorCodes.NotAllowed: ServiceBusError, - AMQPErrorCodes.LinkDetachForced: ServiceBusConnectionError, - ERROR_CODE_MESSAGE_LOCK_LOST: MessageLockLostError, - ERROR_CODE_MESSAGE_NOT_FOUND: MessageNotFoundError, - ERROR_CODE_AUTH_FAILED: ServiceBusAuthorizationError, - ERROR_CODE_ENTITY_DISABLED: MessagingEntityDisabledError, - ERROR_CODE_ENTITY_ALREADY_EXISTS: MessagingEntityAlreadyExistsError, - ERROR_CODE_SERVER_BUSY: ServiceBusServerBusyError, - ERROR_CODE_SESSION_CANNOT_BE_LOCKED: SessionCannotBeLockedError, - ERROR_CODE_SESSION_LOCK_LOST: SessionLockLostError, - ERROR_CODE_ARGUMENT_ERROR: ServiceBusError, - ERROR_CODE_OUT_OF_RANGE: ServiceBusError, - ERROR_CODE_TIMEOUT: OperationTimeoutError, -} diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/management/_management_client.py b/sdk/servicebus/azure-servicebus/azure/servicebus/management/_management_client.py index 45a562306153a..9e3421f6067e1 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/management/_management_client.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/management/_management_client.py @@ -224,7 +224,7 @@ def _create_forward_to_header_tokens(self, entity, kwargs): kwargs["headers"] = kwargs.get("headers", {}) def _populate_header_within_kwargs(uri, header): - token = self._credential.get_token(uri).token.decode() + token = self._credential.get_token(uri).token if not isinstance( self._credential, (ServiceBusSASTokenCredential, ServiceBusSharedKeyCredential), diff --git 
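The `_ERROR_CODE_TO_ERROR_MAPPING` dict removed above implemented condition-to-exception dispatch with a catch-all default, via `_ERROR_CODE_TO_ERROR_MAPPING.get(condition, ServiceBusError)`. A minimal sketch of that pattern, using exception names and condition strings that appear in the diff (subset only):

```python
class ServiceBusError(Exception):
    pass

class OperationTimeoutError(ServiceBusError):
    pass

class ServiceBusServerBusyError(ServiceBusError):
    pass

# Condition -> exception class, mirroring the removed mapping (subset only)
_ERROR_CODE_TO_ERROR_MAPPING = {
    b"com.microsoft:timeout": OperationTimeoutError,
    b"com.microsoft:server-busy": ServiceBusServerBusyError,
}

def error_cls_for(condition):
    # Unknown conditions fall back to the base ServiceBusError
    return _ERROR_CODE_TO_ERROR_MAPPING.get(condition, ServiceBusError)

print(error_cls_for(b"com.microsoft:timeout").__name__)  # OperationTimeoutError
print(error_cls_for(b"something-else").__name__)         # ServiceBusError
```

Dict dispatch keeps the mapping declarative; the pyamqp transport layer now owns this translation instead of `exceptions.py`.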
a/sdk/servicebus/azure-servicebus/azure/servicebus/management/_shared_key_policy.py b/sdk/servicebus/azure-servicebus/azure/servicebus/management/_shared_key_policy.py index 654ecc4dedd21..88238125d5ad3 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/management/_shared_key_policy.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/management/_shared_key_policy.py @@ -37,7 +37,7 @@ def _update_token(self): access_token, self._token_expiry_on = self._credential.get_token( self._endpoint ) - self._token = access_token.decode("utf-8") + self._token = access_token def on_request(self, request): # type: (PipelineRequest) -> None diff --git a/sdk/servicebus/azure-servicebus/conftest.py b/sdk/servicebus/azure-servicebus/conftest.py index c134fc34d88f6..d91636ed6fb37 100644 --- a/sdk/servicebus/azure-servicebus/conftest.py +++ b/sdk/servicebus/azure-servicebus/conftest.py @@ -13,8 +13,6 @@ add_oauth_response_sanitizer, set_custom_default_matcher ) - - collect_ignore = [] @pytest.fixture(scope="session", autouse=True) diff --git a/sdk/servicebus/azure-servicebus/dev_requirements.txt b/sdk/servicebus/azure-servicebus/dev_requirements.txt index 3562e477bc943..8902ca99f4a66 100644 --- a/sdk/servicebus/azure-servicebus/dev_requirements.txt +++ b/sdk/servicebus/azure-servicebus/dev_requirements.txt @@ -4,4 +4,6 @@ -e ../../../tools/azure-sdk-tools azure-mgmt-servicebus~=8.0.0 aiohttp>=3.0 +websocket-client +uamqp>=1.6.3,<2.0.0 azure-mgmt-resource<=16.0.0 \ No newline at end of file diff --git a/sdk/servicebus/azure-servicebus/setup.py b/sdk/servicebus/azure-servicebus/setup.py index b940b9e6ec684..b9f2a29aca453 100644 --- a/sdk/servicebus/azure-servicebus/setup.py +++ b/sdk/servicebus/azure-servicebus/setup.py @@ -64,7 +64,6 @@ 'azure', ]), install_requires=[ - "uamqp>=1.6.3,<2.0.0", "azure-core<2.0.0,>=1.24.0", "isodate>=0.6.0", "typing-extensions>=4.0.1", diff --git a/sdk/servicebus/azure-servicebus/stress/.helmignore 
b/sdk/servicebus/azure-servicebus/stress/.helmignore new file mode 100644 index 0000000000000..61334a0f31eec --- /dev/null +++ b/sdk/servicebus/azure-servicebus/stress/.helmignore @@ -0,0 +1,6 @@ +stress +stress.exe +.env +Dockerfile +*.py +*.txt \ No newline at end of file diff --git a/sdk/servicebus/azure-servicebus/stress/Chart.lock b/sdk/servicebus/azure-servicebus/stress/Chart.lock index f5c554172d821..15552cf4dea9b 100644 --- a/sdk/servicebus/azure-servicebus/stress/Chart.lock +++ b/sdk/servicebus/azure-servicebus/stress/Chart.lock @@ -2,5 +2,5 @@ dependencies: - name: stress-test-addons repository: https://stresstestcharts.blob.core.windows.net/helm/ version: 0.2.0 -digest: sha256:59fff3930e78c4ca9f9c0120433c7695d31db63f36ac61d50abcc91b1f1835a0 -generated: "2022-11-19T01:30:02.403917379Z" +digest: sha256:53cbe4c0fed047f6c611523bd34181b21a310e7a3a21cb14f649bb09e4a77648 +generated: "2023-03-14T09:57:20.6731895-07:00" diff --git a/sdk/servicebus/azure-servicebus/stress/Chart.yaml b/sdk/servicebus/azure-servicebus/stress/Chart.yaml index 3f017a8c86d9d..2bb897675d676 100644 --- a/sdk/servicebus/azure-servicebus/stress/Chart.yaml +++ b/sdk/servicebus/azure-servicebus/stress/Chart.yaml @@ -1,5 +1,5 @@ apiVersion: v2 -name: python-servicebus-stress-test +name: py-sb-stress-test description: python service bus stress test. 
version: 0.1.2 appVersion: v0.2 diff --git a/sdk/servicebus/azure-servicebus/stress/Dockerfile b/sdk/servicebus/azure-servicebus/stress/Dockerfile index 61c94647f7eec..96439c796da8f 100644 --- a/sdk/servicebus/azure-servicebus/stress/Dockerfile +++ b/sdk/servicebus/azure-servicebus/stress/Dockerfile @@ -2,10 +2,12 @@ # public OSS users should simply leave this argument blank or ignore its presence entirely ARG REGISTRY="mcr.microsoft.com/mirror/docker/library/" FROM ${REGISTRY}python:3.8-slim-buster +# Install if running off git branch +# RUN apt-get -y update && apt-get -y install git WORKDIR /app COPY ./scripts /app/stress/scripts WORKDIR /app/stress/scripts -RUN pip3 install -r dev_requirements.txt +RUN pip install -r dev_requirements.txt \ No newline at end of file diff --git a/sdk/servicebus/azure-servicebus/stress/scenarios-matrix.yaml b/sdk/servicebus/azure-servicebus/stress/scenarios-matrix.yaml index 4a4bbec0136e9..04b654df2476a 100644 --- a/sdk/servicebus/azure-servicebus/stress/scenarios-matrix.yaml +++ b/sdk/servicebus/azure-servicebus/stress/scenarios-matrix.yaml @@ -1,7 +1,32 @@ -displayNames: matrix: - image: + image: - Dockerfile scenarios: - sbStress: - testTarget: servicebus \ No newline at end of file + queue: + testTarget: queue + aqueue: + testTarget: aqueue + queuepull: + testTarget: queuepull + aqueuepull: + testTarget: aqueuepull + batch: + testTarget: batch + abatch: + testTarget: abatch + queuew: + testTarget: queuew + aqueuew: + testTarget: aqueuew + queuepullw: + testTarget: queuepullw + aqueuepullw: + testTarget: aqueuepullw + batchw: + testTarget: batchw + abatchw: + testTarget: abatchw + memray: + testTarget: memray + amemray: + testTarget: amemray diff --git a/sdk/servicebus/azure-servicebus/stress/scripts/dev_requirements.txt b/sdk/servicebus/azure-servicebus/stress/scripts/dev_requirements.txt index d7928a902bf49..dc6a5d000de97 100644 --- a/sdk/servicebus/azure-servicebus/stress/scripts/dev_requirements.txt +++ 
b/sdk/servicebus/azure-servicebus/stress/scripts/dev_requirements.txt @@ -1,6 +1,6 @@ -aiohttp>=3.0; python_version >= '3.5' +aiohttp>=3.0 opencensus-ext-azure psutil -pytest azure-servicebus python-dotenv +websocket-client \ No newline at end of file diff --git a/sdk/servicebus/azure-servicebus/stress/scripts/logger.py b/sdk/servicebus/azure-servicebus/stress/scripts/logger.py index 0875c2e8a7e09..3199a5e59afb2 100644 --- a/sdk/servicebus/azure-servicebus/stress/scripts/logger.py +++ b/sdk/servicebus/azure-servicebus/stress/scripts/logger.py @@ -11,77 +11,37 @@ from opencensus.ext.azure.log_exporter import AzureLogHandler -def get_base_logger(log_filename, logger_name, level=logging.INFO, print_console=False, log_format=None, +def get_base_logger(log_filename, logger_name, level=logging.ERROR, print_console=False, log_format=None, log_file_max_bytes=20 * 1024 * 1024, log_file_backup_count=3): logger = logging.getLogger(logger_name) logger.setLevel(level) formatter = log_format or logging.Formatter('%(asctime)s %(name)-12s %(levelname)-8s %(message)s') - - if print_console: - console_handler = logging.StreamHandler(stream=sys.stdout) - if not logger.handlers: - console_handler.setFormatter(formatter) - logger.addHandler(console_handler) - - if log_filename: - logger_file_handler = RotatingFileHandler( - log_filename, - maxBytes=log_file_max_bytes, - backupCount=log_file_backup_count - ) - logger_file_handler.setFormatter(formatter) - logger.addHandler(logger_file_handler) - + console_handler = logging.StreamHandler(sys.stdout) + console_handler.setFormatter(formatter) + logger.addHandler(console_handler) return logger - -def get_logger(log_filename, logger_name, level=logging.INFO, print_console=False, log_format=None, +def get_logger(log_filename, logger_name, level=logging.ERROR, print_console=False, log_format=None, log_file_max_bytes=20 * 1024 * 1024, log_file_backup_count=3): stress_logger = logging.getLogger(logger_name) stress_logger.setLevel(level) - 
eventhub_logger = logging.getLogger("azure.eventhub") - eventhub_logger.setLevel(level) - uamqp_logger = logging.getLogger("uamqp") - uamqp_logger.setLevel(level) + servicebus_logger = logging.getLogger("azure.servicebus") + servicebus_logger.setLevel(level) + pyamqp_logger = logging.getLogger("azure.servicebus._pyamqp") + pyamqp_logger.setLevel(level) formatter = log_format or logging.Formatter('%(asctime)s %(name)-12s %(levelname)-8s %(message)s') - if print_console: - console_handler = logging.StreamHandler(stream=sys.stdout) - console_handler.setFormatter(formatter) - if not eventhub_logger.handlers: - eventhub_logger.addHandler(console_handler) - if not uamqp_logger.handlers: - uamqp_logger.addHandler(console_handler) - if not stress_logger.handlers: - stress_logger.addHandler(console_handler) - if log_filename: - eventhub_file_handler = RotatingFileHandler( - "eventhub_" + log_filename, - maxBytes=log_file_max_bytes, - backupCount=log_file_backup_count - ) - uamqp_file_handler = RotatingFileHandler( - "uamqp_" + log_filename, - maxBytes=log_file_max_bytes, - backupCount=log_file_backup_count - ) - stress_file_handler = RotatingFileHandler( - log_filename, - maxBytes=log_file_max_bytes, - backupCount=log_file_backup_count - ) - eventhub_file_handler.setFormatter(formatter) - uamqp_file_handler.setFormatter(formatter) - stress_file_handler.setFormatter(formatter) - eventhub_logger.addHandler(eventhub_file_handler) - uamqp_logger.addHandler(uamqp_file_handler) - stress_logger.addHandler(stress_file_handler) + console_handler = logging.FileHandler(log_filename) + console_handler.setFormatter(formatter) + servicebus_logger.addHandler(console_handler) + pyamqp_logger.addHandler(console_handler) + stress_logger.addHandler(console_handler) return stress_logger -def get_azure_logger(logger_name, level=logging.INFO): +def get_azure_logger(logger_name, level=logging.ERROR): logger = logging.getLogger("azure_logger_" + logger_name) logger.setLevel(level) # oc will 
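The rewritten logger helpers swap the eventhub/uamqp logger names for `azure.servicebus` and `azure.servicebus._pyamqp` and route everything through a single file handler (note the diff keeps the variable name `console_handler` even though it now constructs a `logging.FileHandler`). A condensed, stdlib-only sketch of the same wiring, with the handler named for what it is:

```python
import logging

def get_logger(log_filename, logger_name, level=logging.ERROR):
    """Attach one shared FileHandler to the stress logger plus the SDK and
    pyamqp loggers, mirroring the logger names used in the diff."""
    formatter = logging.Formatter('%(asctime)s %(name)-12s %(levelname)-8s %(message)s')
    file_handler = logging.FileHandler(log_filename)
    file_handler.setFormatter(formatter)
    for name in (logger_name, "azure.servicebus", "azure.servicebus._pyamqp"):
        log = logging.getLogger(name)
        log.setLevel(level)
        log.addHandler(file_handler)
    return logging.getLogger(logger_name)
```

Because `azure.servicebus._pyamqp` is a child of `azure.servicebus`, records from the pyamqp logger can propagate and be written twice; the diff shares this behavior, so the sketch preserves it.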
automatically search for the ENV VAR 'APPLICATIONINSIGHTS_CONNECTION_STRING' diff --git a/sdk/servicebus/azure-servicebus/stress/scripts/process_monitor.py b/sdk/servicebus/azure-servicebus/stress/scripts/process_monitor.py index b59610a4331f3..8e8178c5f56b8 100644 --- a/sdk/servicebus/azure-servicebus/stress/scripts/process_monitor.py +++ b/sdk/servicebus/azure-servicebus/stress/scripts/process_monitor.py @@ -12,7 +12,7 @@ class ProcessMonitor: - def __init__(self, log_filename, logger_name, log_interval=5.0, print_console=False, + def __init__(self, log_filename, logger_name, log_interval=30.0, print_console=False, process_id=None, **kwargs): """ Process Monitor monitors the CPU usage, memory usage of a specific process. diff --git a/sdk/servicebus/azure-servicebus/stress/scripts/stress_runner.py b/sdk/servicebus/azure-servicebus/stress/scripts/stress_runner.py index c1fe97d8bca1c..fac7c5c8ab0e2 100644 --- a/sdk/servicebus/azure-servicebus/stress/scripts/stress_runner.py +++ b/sdk/servicebus/azure-servicebus/stress/scripts/stress_runner.py @@ -6,8 +6,10 @@ import os import asyncio +import configparser from argparse import ArgumentParser from datetime import timedelta +from dotenv import load_dotenv from azure.servicebus import ServiceBusClient from azure.servicebus.aio import ServiceBusClient as AsyncServiceBusClient @@ -16,9 +18,7 @@ from app_insights_metric import AzureMonitorMetric from process_monitor import ProcessMonitor -CONNECTION_STR = os.environ['SERVICE_BUS_CONNECTION_STR'] -QUEUE_NAME = os.environ["SERVICE_BUS_QUEUE_NAME"] - +ENV_FILE = os.environ.get("ENV_FILE") def sync_send(client, args): azure_monitor_metric = AzureMonitorMetric("Sync ServiceBus Sender") @@ -53,6 +53,9 @@ async def async_send(client, args): def sync_receive(client, args): + config = configparser.ConfigParser() + config.read("./stress_runner.cfg") + azure_monitor_metric = AzureMonitorMetric("Sync ServiceBus Receiver") process_monitor = 
ProcessMonitor("monitor_receiver_stress_sync.log", "receiver_stress_sync") stress_test = StressTestRunner( @@ -87,10 +90,14 @@ async def async_receive(client, args): if __name__ == '__main__': + load_dotenv(dotenv_path=ENV_FILE, override=True) parser = ArgumentParser() + parser.add_argument("--conn_str", help="ServiceBus connection string", + default=os.environ.get('SERVICE_BUS_CONNECTION_STR')) + parser.add_argument("--queue_name", help="The queue name.", default=os.environ.get("SERVICE_BUS_QUEUE_NAME")) parser.add_argument("--method", type=str) parser.add_argument("--duration", type=int, default=259200) - parser.add_argument("--logging-enable", action="store_true") + parser.add_argument("--logging_enable", action="store_true") parser.add_argument("--send-batch-size", type=int, default=100) parser.add_argument("--message-size", type=int, default=100) @@ -102,6 +109,9 @@ async def async_receive(client, args): args, _ = parser.parse_known_args() loop = asyncio.get_event_loop() + CONNECTION_STR = args.conn_str + QUEUE_NAME= args.queue_name + if args.method.startswith("sync"): sb_client = ServiceBusClient.from_connection_string(conn_str=CONNECTION_STR) else: diff --git a/sdk/servicebus/azure-servicebus/stress/scripts/stress_test_base.py b/sdk/servicebus/azure-servicebus/stress/scripts/stress_test_base.py index a815659f46e4b..67c016c9bfb3a 100644 --- a/sdk/servicebus/azure-servicebus/stress/scripts/stress_test_base.py +++ b/sdk/servicebus/azure-servicebus/stress/scripts/stress_test_base.py @@ -9,6 +9,7 @@ from datetime import datetime, timedelta import concurrent import sys +import os import asyncio import logging @@ -20,15 +21,13 @@ from azure.servicebus import ServiceBusMessage, ServiceBusMessageBatch from azure.servicebus.exceptions import MessageAlreadySettled - -import logger +from logger import get_logger from app_insights_metric import AbstractMonitorMetric from process_monitor import ProcessMonitor -LOGFILE_NAME = "stress-test.log" +LOGFILE_NAME = 
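The stress_runner.py hunk replaces module-level environment lookups with CLI flags whose defaults come from the environment after `load_dotenv` runs. A stdlib-only sketch of that argument handling (dotenv loading omitted; `python-dotenv` would simply populate `os.environ` before `parse_args` is called):

```python
import os
from argparse import ArgumentParser

def parse_args(argv=None):
    """Connection details fall back to env vars; unknown flags are ignored,
    matching the runner's use of parse_known_args."""
    parser = ArgumentParser()
    parser.add_argument("--conn_str", default=os.environ.get("SERVICE_BUS_CONNECTION_STR"))
    parser.add_argument("--queue_name", default=os.environ.get("SERVICE_BUS_QUEUE_NAME"))
    parser.add_argument("--method", type=str)
    parser.add_argument("--duration", type=int, default=259200)
    parser.add_argument("--logging_enable", action="store_true")
    args, _ = parser.parse_known_args(argv)
    return args
```

This lets the same script run locally (flags) or in the stress cluster (env vars injected by the deployment) without code changes.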
os.environ.get("DEBUG_SHARE") + "output" PRINT_CONSOLE = True - -_logger = logger.get_base_logger(LOGFILE_NAME, "stress_test", logging.WARN) +_logger = get_logger(LOGFILE_NAME, "stress_test", logging.ERROR) class ReceiveType: @@ -44,11 +43,11 @@ def __init__(self): self.time_elapsed = None self.state_by_sender = {} self.state_by_receiver = {} + self.actual_size = 0 def __repr__(self): return str(vars(self)) - class StressTestRunnerState(object): """Per-runner state, e.g. if you spawn 3 senders each will have this as their state object, which will be coalesced at completion into StressTestResults""" @@ -64,11 +63,11 @@ def __init__(self): def __repr__(self): return str(vars(self)) - def populate_process_stats(self): + def populate_process_stats(self, monitor): self.timestamp = datetime.utcnow() try: self.cpu_percent = psutil.cpu_percent() - self.memory_bytes = psutil.virtual_memory().total + self.memory_bytes = psutil.virtual_memory().percent except NameError: return # psutil was not installed, fall back to simply not capturing these stats. 
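`populate_process_stats` now records `psutil.virtual_memory().percent` rather than `.total` (a usage percentage instead of a constant machine size), and still degrades gracefully when psutil is missing. A self-contained sketch of that snapshot logic (class and attribute names simplified from the diff):

```python
from datetime import datetime

try:
    import psutil  # optional; stats are skipped if unavailable
except ImportError:
    psutil = None

class RunnerState:
    """Per-runner stats snapshot, sketched after the diff's StressTestRunnerState."""
    def __init__(self):
        self.timestamp = None
        self.cpu_percent = None
        self.memory_percent = None

    def populate_process_stats(self):
        self.timestamp = datetime.utcnow()
        if psutil is None:
            return  # fall back to simply not capturing these stats
        self.cpu_percent = psutil.cpu_percent()
        self.memory_percent = psutil.virtual_memory().percent
```

Reporting a percentage makes the recurrent status log useful for spotting leaks, since total system memory never changes over a run.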
@@ -81,22 +80,25 @@ def __init__( self, senders, receivers, + admin_client, duration=timedelta(minutes=15), receive_type=ReceiveType.push, send_batch_size=None, message_size=10, max_wait_time=10, - send_delay=0.01, + send_delay=1.0, receive_delay=0, should_complete_messages=True, - max_message_count=1, + max_message_count=10, send_session_id=None, fail_on_exception=True, azure_monitor_metric=None, process_monitor=None, + logging_level=logging.ERROR, ): self.senders = senders self.receivers = receivers + self.admin_client = admin_client self.duration = duration self.receive_type = receive_type self.message_size = message_size @@ -111,6 +113,7 @@ def __init__( self.azure_monitor_metric = azure_monitor_metric or AbstractMonitorMetric( "fake_test_name" ) + self.logging_level = logging_level self.process_monitor = process_monitor or ProcessMonitor( "monitor_{}".format(LOGFILE_NAME), "test_stress_queues", @@ -123,7 +126,7 @@ def __init__( self._duration_override = None for arg in sys.argv: - if arg.startswith("--stress_test_duration_seconds="): + if arg.startswith("--duration="): self._duration_override = timedelta(seconds=int(arg.split("=")[1])) self._should_stop = False @@ -161,10 +164,10 @@ def pre_process_message_body(self, payload): """Allows user to transform message payload before sending it.""" return payload - def _schedule_interval_logger(self, end_time, description="", interval_seconds=30): + def _schedule_interval_logger(self, end_time, description="", interval_seconds=300): def _do_interval_logging(): if end_time > datetime.utcnow() and not self._should_stop: - self._state.populate_process_stats() + self._state.populate_process_stats(self.process_monitor) _logger.critical( "{} RECURRENT STATUS: {}".format(description, self._state) ) @@ -194,14 +197,16 @@ def _construct_message(self): def _send(self, sender, end_time): self._schedule_interval_logger(end_time, "Sender " + str(self)) try: - _logger.info("STARTING SENDER") + _logger.debug("Starting send loop") + 
# log sender + _logger.debug("Sender: {}".format(sender)) with sender: while end_time > datetime.utcnow() and not self._should_stop: - _logger.info("SENDING") try: message = self._construct_message() if self.send_session_id != None: message.session_id = self.send_session_id + _logger.debug("Sending message: {}".format(message)) sender.send_messages(message) self.azure_monitor_metric.record_messages_cpu_memory( self.send_batch_size, @@ -213,6 +218,7 @@ def _send(self, sender, end_time): else: self._state.total_sent += 1 # send single message self.on_send(self._state, message, sender) + except Exception as e: _logger.exception("Exception during send: {}".format(e)) self.azure_monitor_metric.record_error(e) @@ -229,10 +235,11 @@ def _send(self, sender, end_time): def _receive(self, receiver, end_time): self._schedule_interval_logger(end_time, "Receiver " + str(self)) + # log receiver + _logger.debug("Receiver: {}".format(receiver)) try: with receiver: while end_time > datetime.utcnow() and not self._should_stop: - _logger.info("RECEIVE LOOP") try: if self.receive_type == ReceiveType.pull: batch = receiver.receive_messages( @@ -240,19 +247,21 @@ def _receive(self, receiver, end_time): max_wait_time=self.max_wait_time, ) elif self.receive_type == ReceiveType.push: - batch = receiver._get_streaming_message_iter( - max_wait_time=self.max_wait_time - ) - else: - batch = [] + receiver.max_wait_time = self.max_wait_time + batch = receiver + # else: + # batch = [] for message in batch: + # log reciever + _logger.debug("Received message: {}".format(message)) self.on_receive(self._state, message, receiver) try: if self.should_complete_messages: receiver.complete_message(message) except MessageAlreadySettled: # It may have been settled in the plugin callback. pass + self._state.total_received += 1 # TODO: Get EnqueuedTimeUtc out of broker properties and calculate latency. Should properties/app properties be mostly None? 
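The push-mode receive path above no longer calls the private `_get_streaming_message_iter`; it sets `max_wait_time` on the receiver and iterates the receiver object directly, matching the iterator support added in this PR. A stand-in illustrating the pattern with a dummy iterable receiver (all names here are hypothetical):

```python
from datetime import datetime, timedelta

class DummyReceiver:
    """Stand-in for a receiver that is itself iterable, like the new ServiceBusReceiver."""
    def __init__(self, messages):
        self._messages = list(messages)
        self.max_wait_time = None

    def __iter__(self):
        return iter(self._messages)

def receive_push(receiver, max_wait_time, end_time):
    # Push mode: configure the wait time, then iterate the receiver directly
    # instead of reaching into a private streaming-iterator helper.
    receiver.max_wait_time = max_wait_time
    received = 0
    for message in receiver:
        if end_time <= datetime.utcnow():
            break
        received += 1
    return received
```

Dropping the private helper keeps the stress test on the public surface, so it exercises the same code path users do.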
if end_time <= datetime.utcnow(): @@ -270,7 +279,7 @@ def _receive(self, receiver, end_time): self.azure_monitor_metric.record_error(e) if self.fail_on_exception: raise - self._state.timestamp = datetime.utcnow() + self._state.timestamp = datetime.utcnow() return self._state except Exception as e: self.azure_monitor_metric.record_error(e) @@ -279,20 +288,29 @@ def _receive(self, receiver, end_time): raise def run(self): + start_time = datetime.utcnow() + if isinstance(self.duration, int): + self.duration = timedelta(seconds=self.duration) end_time = start_time + (self._duration_override or self.duration) + with self.process_monitor: with concurrent.futures.ThreadPoolExecutor(max_workers=4) as proc_pool: _logger.info("STARTING PROC POOL") - senders = [ - proc_pool.submit(self._send, sender, end_time) - for sender in self.senders - ] - receivers = [ - proc_pool.submit(self._receive, receiver, end_time) - for receiver in self.receivers - ] - + if self.senders: + senders = [ + proc_pool.submit(self._send, sender, end_time) + for sender in self.senders + ] + else: + senders = [] + if self.receivers: + receivers = [ + proc_pool.submit(self._receive, receiver, end_time) + for receiver in self.receivers + ] + else: + receivers = [] result = StressTestResults() for each in concurrent.futures.as_completed(senders + receivers): _logger.info("SOMETHING FINISHED") @@ -301,25 +319,28 @@ def run(self): if each in receivers: result.state_by_receiver[each] = each.result() # TODO: do as_completed in one batch to provide a way to short-circuit on failure. 
- result.state_by_sender = { - s: f.result() - for s, f in zip( - self.senders, concurrent.futures.as_completed(senders) + if self.senders: + result.state_by_sender = { + s: f.result() + for s, f in zip( + self.senders, concurrent.futures.as_completed(senders) + ) + } + _logger.info("Got receiver results") + result.total_sent = sum( + [r.total_sent for r in result.state_by_sender.values()] ) - } - result.state_by_receiver = { - r: f.result() - for r, f in zip( - self.receivers, concurrent.futures.as_completed(receivers) + if self.receivers: + result.state_by_receiver = { + r: f.result() + for r, f in zip( + self.receivers, concurrent.futures.as_completed(receivers) + ) + } + + result.total_received = sum( + [r.total_received for r in result.state_by_receiver.values()] ) - } - _logger.info("got receiver results") - result.total_sent = sum( - [r.total_sent for r in result.state_by_sender.values()] - ) - result.total_received = sum( - [r.total_received for r in result.state_by_receiver.values()] - ) result.time_elapsed = end_time - start_time _logger.critical("Stress test completed. 
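The reworked `run()` now tolerates empty sender or receiver lists and collects results with `as_completed` on a shared thread pool. The control flow, reduced to a stdlib sketch:

```python
import concurrent.futures

def run_all(send_fns, receive_fns):
    """Submit sender and receiver callables to one pool and gather their
    results as they finish, tolerating an empty list on either side."""
    with concurrent.futures.ThreadPoolExecutor(max_workers=4) as pool:
        senders = [pool.submit(fn) for fn in send_fns] if send_fns else []
        receivers = [pool.submit(fn) for fn in receive_fns] if receive_fns else []
        return [f.result() for f in concurrent.futures.as_completed(senders + receivers)]
```

Allowing either list to be empty is what lets one scenario run send-only and receive-only containers against the same queue.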
Results:\n{}".format(result)) return result @@ -331,11 +352,12 @@ def __init__( senders, receivers, duration=timedelta(minutes=15), + admin_client=None, receive_type=ReceiveType.push, send_batch_size=None, message_size=10, max_wait_time=10, - send_delay=0.01, + send_delay=1.00, receive_delay=0, should_complete_messages=True, max_message_count=1, @@ -343,11 +365,13 @@ def __init__( fail_on_exception=True, azure_monitor_metric=None, process_monitor=None, + logging_level=logging.ERROR, ): super(StressTestRunnerAsync, self).__init__( senders, receivers, duration=duration, + admin_client=admin_client, receive_type=receive_type, send_batch_size=send_batch_size, message_size=message_size, @@ -360,15 +384,14 @@ def __init__( fail_on_exception=fail_on_exception, azure_monitor_metric=azure_monitor_metric, process_monitor=process_monitor, + logging_level=logging_level ) async def _send_async(self, sender, end_time): self._schedule_interval_logger(end_time, "Sender " + str(self)) try: - _logger.info("STARTING SENDER") async with sender: while end_time > datetime.utcnow() and not self._should_stop: - _logger.info("SENDING") try: message = self._construct_message() if self.send_session_id != None: @@ -379,7 +402,10 @@ async def _send_async(self, sender, end_time): self.process_monitor.cpu_usage_percent, self.process_monitor.memory_usage_percent, ) - self._state.total_sent += self.send_batch_size + if self.send_batch_size: + self._state.total_sent += self.send_batch_size + else: + self._state.total_sent += 1 self.on_send(self._state, message, sender) except Exception as e: _logger.exception("Exception during send: {}".format(e)) @@ -416,7 +442,6 @@ async def _receive_async(self, receiver, end_time): try: async with receiver: while end_time > datetime.utcnow() and not self._should_stop: - _logger.info("RECEIVE LOOP") try: if self.receive_type == ReceiveType.pull: batch = await receiver.receive_messages( @@ -428,9 +453,8 @@ async def _receive_async(self, receiver, end_time): 
message, receiver, end_time ) elif self.receive_type == ReceiveType.push: - batch = receiver._get_streaming_message_iter( - max_wait_time=self.max_wait_time - ) + receiver.max_wait_time = self.max_wait_time + batch = receiver async for message in batch: if end_time <= datetime.utcnow(): break @@ -454,31 +478,47 @@ async def _receive_async(self, receiver, end_time): async def run_async(self): start_time = datetime.utcnow() + if isinstance(self.duration, int): + self.duration = timedelta(seconds=self.duration) end_time = start_time + (self._duration_override or self.duration) - send_tasks = [ - asyncio.create_task(self._send_async(sender, end_time)) - for sender in self.senders - ] - receive_tasks = [ - asyncio.create_task(self._receive_async(receiver, end_time)) - for receiver in self.receivers - ] + if self.senders: + send_tasks = [ + asyncio.create_task(self._send_async(sender, end_time)) + for sender in self.senders + ] + else: + send_tasks = [] + if self.receivers: + receive_tasks = [ + asyncio.create_task(self._receive_async(receiver, end_time)) + for receiver in self.receivers + ] + else: + receive_tasks = [] with self.process_monitor: - await asyncio.gather(*send_tasks, *receive_tasks) + # await asyncio.gather(*send_tasks, *receive_tasks) + for task in asyncio.as_completed(send_tasks + receive_tasks): + try: + await task + except Exception as e: + print(e) result = StressTestResults() - result.state_by_sender = { - s: f.result() for s, f in zip(self.senders, send_tasks) - } - result.state_by_receiver = { - r: f.result() for r, f in zip(self.receivers, receive_tasks) - } - _logger.info("got receiver results") - result.total_sent = sum( + if self.senders: + result.state_by_sender = { + s: f.result() for s, f in zip(self.senders, send_tasks) + } + result.total_sent = sum( [r.total_sent for r in result.state_by_sender.values()] - ) - result.total_received = sum( - [r.total_received for r in result.state_by_receiver.values()] - ) + ) + if self.receivers: + 
result.state_by_receiver = { + r: f.result() for r, f in zip(self.receivers, receive_tasks) + } + _logger.info("got receiver results") + + result.total_received = sum( + [r.total_received for r in result.state_by_receiver.values()] + ) result.time_elapsed = end_time - start_time _logger.critical("Stress test completed. Results:\n{}".format(result)) return result diff --git a/sdk/servicebus/azure-servicebus/stress/scripts/test_stress_queues.py b/sdk/servicebus/azure-servicebus/stress/scripts/test_stress_queues.py index 838bafbc46255..6288696d6cd02 100644 --- a/sdk/servicebus/azure-servicebus/stress/scripts/test_stress_queues.py +++ b/sdk/servicebus/azure-servicebus/stress/scripts/test_stress_queues.py @@ -1,147 +1,138 @@ -#------------------------------------------------------------------------- +#------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for # license information. 
#-------------------------------------------------------------------------- from datetime import timedelta +import logging import time import os -import pytest from dotenv import load_dotenv -#from argparse import ArgumentParser +from argparse import ArgumentParser -from azure.servicebus import AutoLockRenewer, ServiceBusClient +from azure.servicebus import AutoLockRenewer, ServiceBusClient, TransportType +from azure.servicebus.management import ServiceBusAdministrationClient from azure.servicebus._common.constants import ServiceBusReceiveMode from app_insights_metric import AzureMonitorMetric from stress_test_base import StressTestRunner, ReceiveType ENV_FILE = os.environ.get('ENV_FILE') -load_dotenv(dotenv_path=ENV_FILE, override=True) -LOGGING_ENABLE = False -SERVICE_BUS_CONNECTION_STR = os.environ.get('SERVICE_BUS_CONNECTION_STR') -SERVICEBUS_QUEUE_NAME = os.environ.get('SERVICE_BUS_QUEUE_NAME') -@pytest.mark.liveTest -@pytest.mark.live_test_only -def test_stress_queue_send_and_receive(): +def test_stress_queue_send_and_receive(args): sb_client = ServiceBusClient.from_connection_string( - SERVICE_BUS_CONNECTION_STR, logging_enable=LOGGING_ENABLE) + SERVICE_BUS_CONNECTION_STR, logging_enable=LOGGING_ENABLE, transport_type=TRANSPORT_TYPE) stress_test = StressTestRunner(senders = [sb_client.get_queue_sender(SERVICEBUS_QUEUE_NAME)], receivers = [sb_client.get_queue_receiver(SERVICEBUS_QUEUE_NAME)], - duration=timedelta(seconds=60), - azure_monitor_metric=AzureMonitorMetric("test_stress_queue_send_and_receive") + admin_client = sb_admin_client, + duration=args.duration, + azure_monitor_metric=AzureMonitorMetric("test_stress_queue_send_and_receive"), + logging_level=LOGGING_LEVEL ) result = stress_test.run() - assert(result.total_sent > 0) - assert(result.total_received > 0) + print(f"Total send {result.total_sent}") + print(f"Total received {result.total_received}") - -@pytest.mark.liveTest -@pytest.mark.live_test_only -def test_stress_queue_send_and_pull_receive(): 
+def test_stress_queue_send_and_pull_receive(args): sb_client = ServiceBusClient.from_connection_string( - SERVICE_BUS_CONNECTION_STR, logging_enable=LOGGING_ENABLE) + SERVICE_BUS_CONNECTION_STR, logging_enable=LOGGING_ENABLE, transport_type=TRANSPORT_TYPE) stress_test = StressTestRunner(senders = [sb_client.get_queue_sender(SERVICEBUS_QUEUE_NAME)], receivers = [sb_client.get_queue_receiver(SERVICEBUS_QUEUE_NAME)], + admin_client = sb_admin_client, receive_type=ReceiveType.pull, - duration=timedelta(seconds=60), - azure_monitor_metric=AzureMonitorMetric("test_stress_queue_send_and_pull_receive") + duration=args.duration, + azure_monitor_metric=AzureMonitorMetric("test_stress_queue_send_and_pull_receive"), + logging_level=LOGGING_LEVEL ) result = stress_test.run() - assert(result.total_sent > 0) - assert(result.total_received > 0) - + print(f"Total send {result.total_sent}") + print(f"Total received {result.total_received}") -@pytest.mark.liveTest -@pytest.mark.live_test_only -def test_stress_queue_batch_send_and_receive(): +def test_stress_queue_batch_send_and_receive(args): sb_client = ServiceBusClient.from_connection_string( - SERVICE_BUS_CONNECTION_STR, logging_enable=LOGGING_ENABLE) + SERVICE_BUS_CONNECTION_STR, logging_enable=LOGGING_ENABLE, transport_type=TRANSPORT_TYPE) stress_test = StressTestRunner(senders = [sb_client.get_queue_sender(SERVICEBUS_QUEUE_NAME)], - receivers = [sb_client.get_queue_receiver(SERVICEBUS_QUEUE_NAME)], - duration=timedelta(seconds=60), + receivers = [sb_client.get_queue_receiver(SERVICEBUS_QUEUE_NAME, prefetch_count=5)], + admin_client = sb_admin_client, + duration=args.duration, send_batch_size=5, - azure_monitor_metric=AzureMonitorMetric("test_stress_queue_batch_send_and_receive") + azure_monitor_metric=AzureMonitorMetric("test_stress_queue_batch_send_and_receive"), + logging_level=LOGGING_LEVEL ) result = stress_test.run() - assert(result.total_sent > 0) - assert(result.total_received > 0) + print(f"Total send 
{result.total_sent}") + print(f"Total received {result.total_received}") - -@pytest.mark.liveTest -@pytest.mark.live_test_only -def test_stress_queue_slow_send_and_receive(): +def test_stress_queue_slow_send_and_receive(args): sb_client = ServiceBusClient.from_connection_string( - SERVICE_BUS_CONNECTION_STR, logging_enable=LOGGING_ENABLE) + SERVICE_BUS_CONNECTION_STR, logging_enable=LOGGING_ENABLE, transport_type=TRANSPORT_TYPE) stress_test = StressTestRunner(senders = [sb_client.get_queue_sender(SERVICEBUS_QUEUE_NAME)], receivers = [sb_client.get_queue_receiver(SERVICEBUS_QUEUE_NAME)], - duration=timedelta(seconds=3501*3), - send_delay=3500, - azure_monitor_metric=AzureMonitorMetric("test_stress_queue_slow_send_and_receive") + admin_client = sb_admin_client, + duration=args.duration, + send_delay=(args.duration/3), + azure_monitor_metric=AzureMonitorMetric("test_stress_queue_slow_send_and_receive"), + logging_level=LOGGING_LEVEL ) result = stress_test.run() - assert(result.total_sent > 0) - assert(result.total_received > 0) - + print(f"Total send {result.total_sent}") + print(f"Total received {result.total_received}") -@pytest.mark.liveTest -@pytest.mark.live_test_only -def test_stress_queue_receive_and_delete(): +def test_stress_queue_receive_and_delete(args): sb_client = ServiceBusClient.from_connection_string( - SERVICE_BUS_CONNECTION_STR, logging_enable=LOGGING_ENABLE) + SERVICE_BUS_CONNECTION_STR, logging_enable=LOGGING_ENABLE, transport_type=TRANSPORT_TYPE) stress_test = StressTestRunner(senders = [sb_client.get_queue_sender(SERVICEBUS_QUEUE_NAME)], receivers = [sb_client.get_queue_receiver(SERVICEBUS_QUEUE_NAME, receive_mode=ServiceBusReceiveMode.RECEIVE_AND_DELETE)], + admin_client = sb_admin_client, should_complete_messages = False, - duration=timedelta(seconds=60), - azure_monitor_metric=AzureMonitorMetric("test_stress_queue_slow_send_and_receive") + duration=args.duration, + 
azure_monitor_metric=AzureMonitorMetric("test_stress_queue_slow_send_and_receive"), + logging_level=LOGGING_LEVEL ) result = stress_test.run() - assert(result.total_sent > 0) - assert(result.total_received > 0) + print(f"Total send {result.total_sent}") + print(f"Total received {result.total_received}") - -@pytest.mark.liveTest -@pytest.mark.live_test_only -def test_stress_queue_unsettled_messages(): +def test_stress_queue_unsettled_messages(args): sb_client = ServiceBusClient.from_connection_string( - SERVICE_BUS_CONNECTION_STR, logging_enable=LOGGING_ENABLE) + SERVICE_BUS_CONNECTION_STR, logging_enable=LOGGING_ENABLE, transport_type=TRANSPORT_TYPE) stress_test = StressTestRunner(senders = [sb_client.get_queue_sender(SERVICEBUS_QUEUE_NAME)], receivers = [sb_client.get_queue_receiver(SERVICEBUS_QUEUE_NAME)], - duration = timedelta(seconds=350), + admin_client = sb_admin_client, + duration=args.duration, should_complete_messages = False, - azure_monitor_metric=AzureMonitorMetric("test_stress_queue_unsettled_messages") + azure_monitor_metric=AzureMonitorMetric("test_stress_queue_unsettled_messages"), + logging_level=LOGGING_LEVEL ) result = stress_test.run() # This test is prompted by reports of an issue where enough unsettled messages saturate a service-side cache # and prevent further receipt. 
- assert(result.total_sent > 2500) - assert(result.total_received > 2500) - + print(f"Total send {result.total_sent}") + print(f"Total received {result.total_received}") -@pytest.mark.liveTest -@pytest.mark.live_test_only -def test_stress_queue_receive_large_batch_size(): +def test_stress_queue_receive_large_batch_size(args): sb_client = ServiceBusClient.from_connection_string( - SERVICE_BUS_CONNECTION_STR, logging_enable=LOGGING_ENABLE) + SERVICE_BUS_CONNECTION_STR, logging_enable=LOGGING_ENABLE, transport_type=TRANSPORT_TYPE) stress_test = StressTestRunner(senders = [sb_client.get_queue_sender(SERVICEBUS_QUEUE_NAME)], receivers = [sb_client.get_queue_receiver(SERVICEBUS_QUEUE_NAME, prefetch_count=50)], - duration = timedelta(seconds=60), + admin_client = sb_admin_client, + duration = args.duration, max_message_count = 50, - azure_monitor_metric=AzureMonitorMetric("test_stress_queue_receive_large_batch_size") + azure_monitor_metric=AzureMonitorMetric("test_stress_queue_receive_large_batch_size"), + logging_level=LOGGING_LEVEL ) result = stress_test.run() - assert(result.total_sent > 0) - assert(result.total_received > 0) + print(f"Total send {result.total_sent}") + print(f"Total received {result.total_received}") # Cannot be defined at local scope due to pickling into multiproc runner. 
class ReceiverTimeoutStressTestRunner(StressTestRunner): @@ -151,24 +142,23 @@ def on_send(self, state, sent_message, sender): # To make receive time out, in push mode this delay would trigger receiver reconnection time.sleep(self.max_wait_time + 5) -@pytest.mark.liveTest -@pytest.mark.live_test_only -def test_stress_queue_pull_receive_timeout(): +def test_stress_queue_pull_receive_timeout(args): sb_client = ServiceBusClient.from_connection_string( - SERVICE_BUS_CONNECTION_STR, logging_enable=LOGGING_ENABLE) + SERVICE_BUS_CONNECTION_STR, logging_enable=LOGGING_ENABLE, transport_type=TRANSPORT_TYPE) stress_test = ReceiverTimeoutStressTestRunner( senders = [sb_client.get_queue_sender(SERVICEBUS_QUEUE_NAME)], receivers = [sb_client.get_queue_receiver(SERVICEBUS_QUEUE_NAME)], + admin_client = sb_admin_client, max_wait_time = 5, receive_type=ReceiveType.pull, - duration=timedelta(seconds=600), - azure_monitor_metric=AzureMonitorMetric("test_stress_queue_pull_receive_timeout") + duration=args.duration, + azure_monitor_metric=AzureMonitorMetric("test_stress_queue_pull_receive_timeout"), + logging_level=LOGGING_LEVEL ) result = stress_test.run() - assert(result.total_sent > 0) - assert(result.total_received > 0) - + print(f"Total send {result.total_sent}") + print(f"Total received {result.total_received}") class LongRenewStressTestRunner(StressTestRunner): def on_receive(self, state, received_message, receiver): @@ -177,23 +167,22 @@ def on_receive(self, state, received_message, receiver): renewer.register(receiver, received_message, max_lock_renewal_duration=300) time.sleep(300) -@pytest.mark.liveTest -@pytest.mark.live_test_only -def test_stress_queue_long_renew_send_and_receive(): +def test_stress_queue_long_renew_send_and_receive(args): sb_client = ServiceBusClient.from_connection_string( - SERVICE_BUS_CONNECTION_STR, logging_enable=LOGGING_ENABLE) + SERVICE_BUS_CONNECTION_STR, logging_enable=LOGGING_ENABLE, transport_type=TRANSPORT_TYPE) stress_test = 
LongRenewStressTestRunner( senders = [sb_client.get_queue_sender(SERVICEBUS_QUEUE_NAME)], receivers = [sb_client.get_queue_receiver(SERVICEBUS_QUEUE_NAME)], - duration=timedelta(seconds=3000), + admin_client = sb_admin_client, + duration=args.duration, send_delay=300, - azure_monitor_metric=AzureMonitorMetric("test_stress_queue_long_renew_send_and_receive") + azure_monitor_metric=AzureMonitorMetric("test_stress_queue_long_renew_send_and_receive"), + logging_level=LOGGING_LEVEL ) result = stress_test.run() - assert(result.total_sent > 0) - assert(result.total_received > 0) - + print(f"Total send {result.total_sent}") + print(f"Total received {result.total_received}") class LongSessionRenewStressTestRunner(StressTestRunner): def on_receive(self, state, received_message, receiver): @@ -203,25 +192,25 @@ def on_fail(renewable, error): print("FAILED AUTOLOCKRENEW: " + str(error)) renewer.register(receiver, receiver.session, max_lock_renewal_duration=600, on_lock_renew_failure=on_fail) -@pytest.mark.liveTest -@pytest.mark.live_test_only -def test_stress_queue_long_renew_session_send_and_receive(): +def test_stress_queue_long_renew_session_send_and_receive(args): sb_client = ServiceBusClient.from_connection_string( - SERVICE_BUS_CONNECTION_STR, logging_enable=LOGGING_ENABLE) + SERVICE_BUS_CONNECTION_STR, logging_enable=LOGGING_ENABLE, transport_type=TRANSPORT_TYPE) session_id = 'test_stress_queue_long_renew_send_and_receive' stress_test = LongSessionRenewStressTestRunner( senders = [sb_client.get_queue_sender(SERVICEBUS_QUEUE_NAME)], receivers = [sb_client.get_queue_receiver(SERVICEBUS_QUEUE_NAME, session_id=session_id)], - duration=timedelta(seconds=3000), + admin_client = sb_admin_client, + duration=args.duration, send_delay=300, send_session_id=session_id, - azure_monitor_metric=AzureMonitorMetric("test_stress_queue_long_renew_session_send_and_receive") + azure_monitor_metric=AzureMonitorMetric("test_stress_queue_long_renew_session_send_and_receive"), + 
logging_level=LOGGING_LEVEL ) result = stress_test.run() - assert(result.total_sent > 0) - assert(result.total_received > 0) + print(f"Total send {result.total_sent}") + print(f"Total received {result.total_received}") class Peekon_receiveStressTestRunner(StressTestRunner): @@ -229,22 +218,23 @@ def on_receive_batch(self, state, received_message, receiver): '''Called on every successful receive''' assert receiver.peek_messages()[0] -@pytest.mark.liveTest -@pytest.mark.live_test_only -def test_stress_queue_peek_messages(): +def test_stress_queue_peek_messages(args): sb_client = ServiceBusClient.from_connection_string( - SERVICE_BUS_CONNECTION_STR, logging_enable=LOGGING_ENABLE) + SERVICE_BUS_CONNECTION_STR, logging_enable=LOGGING_ENABLE, transport_type=TRANSPORT_TYPE) stress_test = Peekon_receiveStressTestRunner( senders = [sb_client.get_queue_sender(SERVICEBUS_QUEUE_NAME)], receivers = [sb_client.get_queue_receiver(SERVICEBUS_QUEUE_NAME)], - duration = timedelta(seconds=300), + admin_client = sb_admin_client, + duration=args.duration, receive_delay = 30, receive_type = ReceiveType.none, - azure_monitor_metric=AzureMonitorMetric("test_stress_queue_peek_messages") + azure_monitor_metric=AzureMonitorMetric("test_stress_queue_peek_messages"), + logging_level=LOGGING_LEVEL ) result = stress_test.run() - assert(result.total_sent > 0) + print(f"Total send {result.total_sent}") + print(f"Total received {result.total_received}") # TODO: This merits better validation, to be implemented alongside full metric spread. 
@@ -261,24 +251,23 @@ def on_send(self, state, sent_message, sender): sender.__exit__() sender.__enter__() -@pytest.mark.liveTest -@pytest.mark.live_test_only -@pytest.mark.skip(reason='This test is disabled unless re-openability of handlers is desired and re-enabled') -def test_stress_queue_close_and_reopen(): +def test_stress_queue_close_and_reopen(args): sb_client = ServiceBusClient.from_connection_string( - SERVICE_BUS_CONNECTION_STR, logging_enable=LOGGING_ENABLE) + SERVICE_BUS_CONNECTION_STR, logging_enable=LOGGING_ENABLE, transport_type=TRANSPORT_TYPE) stress_test = RestartHandlerStressTestRunner( senders = [sb_client.get_queue_sender(SERVICEBUS_QUEUE_NAME)], receivers = [sb_client.get_queue_receiver(SERVICEBUS_QUEUE_NAME)], - duration = timedelta(seconds=300), + admin_client = sb_admin_client, + duration = args.duration, receive_delay = 30, send_delay = 10, - azure_monitor_metric=AzureMonitorMetric("test_stress_queue_close_and_reopen") + azure_monitor_metric=AzureMonitorMetric("test_stress_queue_close_and_reopen"), + logging_level=LOGGING_LEVEL ) result = stress_test.run() - assert(result.total_sent > 0) - assert(result.total_received > 0) + print(f"Total send {result.total_sent}") + print(f"Total received {result.total_received}") # This test validates that all individual messages are received contiguously over a long time period. # (e.g. 
not dropped for whatever reason, not sent, or not received) @@ -310,23 +299,96 @@ def pre_process_message_body(self, payload): return str(body) -@pytest.mark.liveTest -@pytest.mark.live_test_only -def test_stress_queue_check_for_dropped_messages(): +def test_stress_queue_check_for_dropped_messages(args): sb_client = ServiceBusClient.from_connection_string( - SERVICE_BUS_CONNECTION_STR, logging_enable=LOGGING_ENABLE) + SERVICE_BUS_CONNECTION_STR, logging_enable=LOGGING_ENABLE, transport_type=TRANSPORT_TYPE) stress_test = DroppedMessageCheckerStressTestRunner( senders = [sb_client.get_queue_sender(SERVICEBUS_QUEUE_NAME)], receivers = [sb_client.get_queue_receiver(SERVICEBUS_QUEUE_NAME)], + admin_client = sb_admin_client, receive_type=ReceiveType.pull, - duration=timedelta(seconds=3000), - azure_monitor_metric=AzureMonitorMetric("test_stress_queue_check_for_dropped_messages") + duration=args.duration, + azure_monitor_metric=AzureMonitorMetric("test_stress_queue_check_for_dropped_messages"), + logging_level=LOGGING_LEVEL ) result = stress_test.run() - assert(result.total_sent > 0) - assert(result.total_received > 0) + print(f"Total send {result.total_sent}") + print(f"Total received {result.total_received}") if __name__ == '__main__': - #parser = ArgumentParser() - pytest.main() + load_dotenv(dotenv_path=ENV_FILE, override=True) + parser = ArgumentParser() + parser.add_argument("--conn_str", help="ServiceBus connection string", + default=os.environ.get('SERVICE_BUS_CONNECTION_STR')) + parser.add_argument("--queue_name", help="The queue name.", default="testQueue") + parser.add_argument("--method", type=str) + parser.add_argument("--duration", type=int, default=259200) + parser.add_argument("--logging-enable", action="store_true") + parser.add_argument("--print_console", action="store_true") + + parser.add_argument("--send-batch-size", type=int, default=100) + parser.add_argument("--message-size", type=int, default=100) + + parser.add_argument("--receive-type", 
type=str, default="pull") + parser.add_argument("--max_wait_time", type=int, default=10) + parser.add_argument("--max_message_count", type=int, default=1) + parser.add_argument("--uamqp_mode", action="store_true") + parser.add_argument("--transport", action="store_true") + parser.add_argument("--debug_level", help="Logging level; one of Info, Debug, Warning, Error or Critical", type=str, default="Error") + + args, _ = parser.parse_known_args() + + if args.transport: + TRANSPORT_TYPE = TransportType.AmqpOverWebsocket + else: + TRANSPORT_TYPE = TransportType.Amqp + + SERVICE_BUS_CONNECTION_STR = args.conn_str + SERVICEBUS_QUEUE_NAME = args.queue_name + LOGGING_ENABLE = args.logging_enable + LOGGING_LEVEL = getattr(logging, args.debug_level.upper(), logging.ERROR) + + sb_admin_client = ServiceBusAdministrationClient.from_connection_string(SERVICE_BUS_CONNECTION_STR) + + if args.method == "send_receive": + test_stress_queue_send_and_receive(args) + elif args.method == "send_pull_receive": + test_stress_queue_send_and_pull_receive(args) + elif args.method == "send_receive_batch": + test_stress_queue_batch_send_and_receive(args) + elif args.method == "send_receive_slow": + test_stress_queue_slow_send_and_receive(args) + elif args.method == "receive_delete": + test_stress_queue_receive_and_delete(args) + elif args.method == "unsettled_message": + test_stress_queue_unsettled_messages(args) + elif args.method == "large_batch": + test_stress_queue_receive_large_batch_size(args) + elif args.method == "pull_receive_timeout": + test_stress_queue_pull_receive_timeout(args) + elif args.method == "long_renew": + test_stress_queue_long_renew_send_and_receive(args) + elif args.method == "long_renew_session": + test_stress_queue_long_renew_session_send_and_receive(args) + elif args.method == "queue_peek": + test_stress_queue_peek_messages(args) + elif args.method == "queue_close_reopen": + test_stress_queue_close_and_reopen(args) + elif args.method == "dropped_messages": + 
test_stress_queue_check_for_dropped_messages(args) + else: + test_stress_queue_send_and_receive(args) + test_stress_queue_send_and_pull_receive(args) + test_stress_queue_batch_send_and_receive(args) + test_stress_queue_slow_send_and_receive(args) + test_stress_queue_receive_and_delete(args) + test_stress_queue_unsettled_messages(args) + test_stress_queue_receive_large_batch_size(args) + test_stress_queue_pull_receive_timeout(args) + test_stress_queue_long_renew_send_and_receive(args) + test_stress_queue_long_renew_session_send_and_receive(args) + test_stress_queue_peek_messages(args) + test_stress_queue_close_and_reopen(args) + test_stress_queue_check_for_dropped_messages(args) + diff --git a/sdk/servicebus/azure-servicebus/stress/scripts/test_stress_queues_async.py b/sdk/servicebus/azure-servicebus/stress/scripts/test_stress_queues_async.py new file mode 100644 index 0000000000000..eb9f05765cf23 --- /dev/null +++ b/sdk/servicebus/azure-servicebus/stress/scripts/test_stress_queues_async.py @@ -0,0 +1,399 @@ +#------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. 
+#-------------------------------------------------------------------------- + +from datetime import timedelta +import logging +import time +import os +import asyncio +from dotenv import load_dotenv +from argparse import ArgumentParser + +from azure.servicebus import AutoLockRenewer, TransportType +from azure.servicebus.aio import ServiceBusClient +from azure.servicebus.aio.management import ServiceBusAdministrationClient +from azure.servicebus._common.constants import ServiceBusReceiveMode +from app_insights_metric import AzureMonitorMetric + +from stress_test_base import StressTestRunnerAsync, ReceiveType + +ENV_FILE = os.environ.get('ENV_FILE') + + +async def test_stress_queue_send_and_receive(args): + sb_client = ServiceBusClient.from_connection_string( + SERVICE_BUS_CONNECTION_STR, logging_enable=LOGGING_ENABLE, transport_type=TRANSPORT_TYPE) + stress_test = StressTestRunnerAsync(senders = [sb_client.get_queue_sender(SERVICEBUS_QUEUE_NAME)], + receivers = [sb_client.get_queue_receiver(SERVICEBUS_QUEUE_NAME, max_wait_time=10)], + admin_client = sb_admin_client, + duration=args.duration, + azure_monitor_metric=AzureMonitorMetric("test_stress_queue_send_and_receive"), + logging_level=LOGGING_LEVEL + ) + + result = await stress_test.run_async() + print(f"Total send {result.total_sent}") + print(f"Total received {result.total_received}") + +async def test_stress_queue_send_and_pull_receive(args): + sb_client = ServiceBusClient.from_connection_string( + SERVICE_BUS_CONNECTION_STR, logging_enable=LOGGING_ENABLE, transport_type=TRANSPORT_TYPE) + stress_test = StressTestRunnerAsync(senders = [sb_client.get_queue_sender(SERVICEBUS_QUEUE_NAME)], + receivers = [sb_client.get_queue_receiver(SERVICEBUS_QUEUE_NAME, max_wait_time=10)], + admin_client = sb_admin_client, + receive_type=ReceiveType.pull, + duration=args.duration, + azure_monitor_metric=AzureMonitorMetric("test_stress_queue_send_and_pull_receive"), + logging_level=LOGGING_LEVEL + ) + + result = await 
stress_test.run_async() + print(f"Total send {result.total_sent}") + print(f"Total received {result.total_received}") + +async def test_stress_queue_batch_send_and_receive(args): + sb_client = ServiceBusClient.from_connection_string( + SERVICE_BUS_CONNECTION_STR, logging_enable=LOGGING_ENABLE, transport_type=TRANSPORT_TYPE) + stress_test = StressTestRunnerAsync(senders = [sb_client.get_queue_sender(SERVICEBUS_QUEUE_NAME)], + receivers = [sb_client.get_queue_receiver(SERVICEBUS_QUEUE_NAME, prefetch_count=5, max_wait_time=10)], + admin_client = sb_admin_client, + duration=args.duration, + send_batch_size=5, + azure_monitor_metric=AzureMonitorMetric("test_stress_queue_batch_send_and_receive"), + logging_level=LOGGING_LEVEL + ) + + result = await stress_test.run_async() + print(f"Total send {result.total_sent}") + print(f"Total received {result.total_received}") + +async def test_stress_queue_slow_send_and_receive(args): + sb_client = ServiceBusClient.from_connection_string( + SERVICE_BUS_CONNECTION_STR, logging_enable=LOGGING_ENABLE, transport_type=TRANSPORT_TYPE) + stress_test = StressTestRunnerAsync(senders = [sb_client.get_queue_sender(SERVICEBUS_QUEUE_NAME)], + receivers = [sb_client.get_queue_receiver(SERVICEBUS_QUEUE_NAME, max_wait_time=10)], + admin_client = sb_admin_client, + duration=args.duration, + send_delay=(args.duration/3), + azure_monitor_metric=AzureMonitorMetric("test_stress_queue_slow_send_and_receive"), + logging_level=LOGGING_LEVEL + ) + + result = await stress_test.run_async() + print(f"Total send {result.total_sent}") + print(f"Total received {result.total_received}") + +async def test_stress_queue_receive_and_delete(args): + sb_client = ServiceBusClient.from_connection_string( + SERVICE_BUS_CONNECTION_STR, logging_enable=LOGGING_ENABLE, transport_type=TRANSPORT_TYPE) + stress_test = StressTestRunnerAsync(senders = [sb_client.get_queue_sender(SERVICEBUS_QUEUE_NAME)], + receivers = [sb_client.get_queue_receiver(SERVICEBUS_QUEUE_NAME, 
receive_mode=ServiceBusReceiveMode.RECEIVE_AND_DELETE, max_wait_time=10)], + admin_client = sb_admin_client, + should_complete_messages = False, + duration=args.duration, + azure_monitor_metric=AzureMonitorMetric("test_stress_queue_receive_and_delete"), + logging_level=LOGGING_LEVEL + ) + + result = await stress_test.run_async() + print(f"Total send {result.total_sent}") + print(f"Total received {result.total_received}") + +async def test_stress_queue_unsettled_messages(args): + sb_client = ServiceBusClient.from_connection_string( + SERVICE_BUS_CONNECTION_STR, logging_enable=LOGGING_ENABLE, transport_type=TRANSPORT_TYPE) + stress_test = StressTestRunnerAsync(senders = [sb_client.get_queue_sender(SERVICEBUS_QUEUE_NAME)], + receivers = [sb_client.get_queue_receiver(SERVICEBUS_QUEUE_NAME, max_wait_time=10)], + admin_client = sb_admin_client, + duration=args.duration, + should_complete_messages = False, + azure_monitor_metric=AzureMonitorMetric("test_stress_queue_unsettled_messages"), + logging_level=LOGGING_LEVEL + ) + + result = await stress_test.run_async() + # This test is prompted by reports of an issue where enough unsettled messages saturate a service-side cache + # and prevent further receipt. 
+ print(f"Total send {result.total_sent}") + print(f"Total received {result.total_received}") + +async def test_stress_queue_receive_large_batch_size(args): + sb_client = ServiceBusClient.from_connection_string( + SERVICE_BUS_CONNECTION_STR, logging_enable=LOGGING_ENABLE, transport_type=TRANSPORT_TYPE) + stress_test = StressTestRunnerAsync(senders = [sb_client.get_queue_sender(SERVICEBUS_QUEUE_NAME)], + receivers = [sb_client.get_queue_receiver(SERVICEBUS_QUEUE_NAME, prefetch_count=50, max_wait_time=10)], + admin_client = sb_admin_client, + duration = args.duration, + max_message_count = 50, + azure_monitor_metric=AzureMonitorMetric("test_stress_queue_receive_large_batch_size"), + logging_level=LOGGING_LEVEL + ) + + result = await stress_test.run_async() + print(f"Total send {result.total_sent}") + print(f"Total received {result.total_received}") + +# Cannot be async defined at local scope due to pickling into multiproc runner. +class ReceiverTimeoutStressTestRunner(StressTestRunnerAsync): + def on_send(self, state, sent_message, sender): + '''Called on every successful send''' + if state.total_sent % 10 == 0: + # To make receive time out, in push mode this delay would trigger receiver reconnection + time.sleep(self.max_wait_time + 5) + +async def test_stress_queue_pull_receive_timeout(args): + sb_client = ServiceBusClient.from_connection_string( + SERVICE_BUS_CONNECTION_STR, logging_enable=LOGGING_ENABLE, transport_type=TRANSPORT_TYPE) + stress_test = ReceiverTimeoutStressTestRunner( + senders = [sb_client.get_queue_sender(SERVICEBUS_QUEUE_NAME)], + receivers = [sb_client.get_queue_receiver(SERVICEBUS_QUEUE_NAME, max_wait_time=10)], + admin_client = sb_admin_client, + max_wait_time = 5, + receive_type=ReceiveType.pull, + duration=args.duration, + azure_monitor_metric=AzureMonitorMetric("test_stress_queue_pull_receive_timeout"), + logging_level=LOGGING_LEVEL + ) + + result = await stress_test.run_async() + print(f"Total send {result.total_sent}") + print(f"Total 
received {result.total_received}") + +class LongRenewStressTestRunner(StressTestRunnerAsync): + def on_receive(self, state, received_message, receiver): + '''Called on every successful receive''' + renewer = AutoLockRenewer() + renewer.register(receiver, received_message, max_lock_renewal_duration=300) + time.sleep(300) + +async def test_stress_queue_long_renew_send_and_receive(args): + sb_client = ServiceBusClient.from_connection_string( + SERVICE_BUS_CONNECTION_STR, logging_enable=LOGGING_ENABLE, transport_type=TRANSPORT_TYPE) + stress_test = LongRenewStressTestRunner( + senders = [sb_client.get_queue_sender(SERVICEBUS_QUEUE_NAME)], + receivers = [sb_client.get_queue_receiver(SERVICEBUS_QUEUE_NAME, max_wait_time=10)], + admin_client = sb_admin_client, + duration=args.duration, + send_delay=300, + azure_monitor_metric=AzureMonitorMetric("test_stress_queue_long_renew_send_and_receive"), + logging_level=LOGGING_LEVEL + ) + + result = await stress_test.run_async() + print(f"Total send {result.total_sent}") + print(f"Total received {result.total_received}") + +class LongSessionRenewStressTestRunner(StressTestRunnerAsync): + def on_receive(self, state, received_message, receiver): + '''Called on every successful receive''' + renewer = AutoLockRenewer() + def on_fail(renewable, error): + print("FAILED AUTOLOCKRENEW: " + str(error)) + renewer.register(receiver, receiver.session, max_lock_renewal_duration=600, on_lock_renew_failure=on_fail) + +async def test_stress_queue_long_renew_session_send_and_receive(args): + sb_client = ServiceBusClient.from_connection_string( + SERVICE_BUS_CONNECTION_STR, logging_enable=LOGGING_ENABLE, transport_type=TRANSPORT_TYPE) + session_id = 'test_stress_queue_long_renew_send_and_receive' + + stress_test = LongSessionRenewStressTestRunner( + senders = [sb_client.get_queue_sender(SERVICEBUS_QUEUE_NAME)], + receivers = [sb_client.get_queue_receiver(SERVICEBUS_QUEUE_NAME, session_id=session_id, max_wait_time=10)], + admin_client = 
sb_admin_client, + duration=args.duration, + send_delay=300, + send_session_id=session_id, + azure_monitor_metric=AzureMonitorMetric("test_stress_queue_long_renew_session_send_and_receive"), + logging_level=LOGGING_LEVEL + ) + + result = await stress_test.run_async() + print(f"Total send {result.total_sent}") + print(f"Total received {result.total_received}") + +class Peekon_receiveStressTestRunner(StressTestRunnerAsync): + def on_receive_batch(self, state, received_message, receiver): + '''Called on every successful receive''' + assert receiver.peek_messages()[0] + +async def test_stress_queue_peek_messages(args): + sb_client = ServiceBusClient.from_connection_string( + SERVICE_BUS_CONNECTION_STR, logging_enable=LOGGING_ENABLE, transport_type=TRANSPORT_TYPE) + stress_test = Peekon_receiveStressTestRunner( + senders = [sb_client.get_queue_sender(SERVICEBUS_QUEUE_NAME)], + receivers = [sb_client.get_queue_receiver(SERVICEBUS_QUEUE_NAME, max_wait_time=10)], + admin_client = sb_admin_client, + duration=args.duration, + receive_delay = 30, + receive_type = ReceiveType.none, + azure_monitor_metric=AzureMonitorMetric("test_stress_queue_peek_messages"), + logging_level=LOGGING_LEVEL + ) + + result = await stress_test.run_async() + print(f"Total send {result.total_sent}") + print(f"Total received {result.total_received}") + # TODO: This merits better validation, to be implemented alongside full metric spread. 
+ + +class RestartHandlerStressTestRunner(StressTestRunnerAsync): + def post_receive(self, state, receiver): + '''Called after completion of every successful receive''' + if state.total_received % 3 == 0: + receiver.__exit__() + receiver.__enter__() + + def on_send(self, state, sent_message, sender): + '''Called after completion of every successful send''' + if state.total_sent % 3 == 0: + sender.__exit__() + sender.__enter__() + +async def test_stress_queue_close_and_reopen(args): + sb_client = ServiceBusClient.from_connection_string( + SERVICE_BUS_CONNECTION_STR, logging_enable=LOGGING_ENABLE, transport_type=TRANSPORT_TYPE) + stress_test = RestartHandlerStressTestRunner( + senders = [sb_client.get_queue_sender(SERVICEBUS_QUEUE_NAME)], + receivers = [sb_client.get_queue_receiver(SERVICEBUS_QUEUE_NAME, max_wait_time=10)], + admin_client = sb_admin_client, + duration = args.duration, + receive_delay = 30, + send_delay = 10, + azure_monitor_metric=AzureMonitorMetric("test_stress_queue_close_and_reopen"), + logging_level=LOGGING_LEVEL + ) + + result = await stress_test.run_async() + print(f"Total send {result.total_sent}") + print(f"Total received {result.total_received}") + +# This test validates that all individual messages are received contiguously over a long time period. +# (e.g. 
not dropped for whatever reason, not sent, or not received) +class DroppedMessageCheckerStressTestRunner(StressTestRunnerAsync): + def on_receive(self, state, received_message, receiver): + '''Called on every successful receive''' + last_seen = getattr(state, 'last_seen', -1) + noncontiguous = getattr(state, 'noncontiguous', set()) + body = int(str(received_message)) + if body == last_seen+1: + last_seen += 1 + if noncontiguous: + while (last_seen+1) in noncontiguous: + last_seen += 1 + noncontiguous.remove(last_seen) + else: + noncontiguous.add(body) + state.noncontiguous = noncontiguous + state.last_seen = last_seen + + def pre_process_message_body(self, payload): + '''Called when constructing message body''' + try: + body = self._message_id + except AttributeError: + self._message_id = 0 + body = 0 + self._message_id += 1 + + return str(body) + +async def test_stress_queue_check_for_dropped_messages(args): + sb_client = ServiceBusClient.from_connection_string( + SERVICE_BUS_CONNECTION_STR, logging_enable=LOGGING_ENABLE, transport_type=TRANSPORT_TYPE) + stress_test = DroppedMessageCheckerStressTestRunner( + senders = [sb_client.get_queue_sender(SERVICEBUS_QUEUE_NAME)], + receivers = [sb_client.get_queue_receiver(SERVICEBUS_QUEUE_NAME, max_wait_time=10)], + admin_client = sb_admin_client, + receive_type=ReceiveType.pull, + duration=args.duration, + azure_monitor_metric=AzureMonitorMetric("test_stress_queue_check_for_dropped_messages"), + logging_level=LOGGING_LEVEL + ) + + result = await stress_test.run_async() + print(f"Total send {result.total_sent}") + print(f"Total received {result.total_received}") + +async def run(args): + if args.method == "send_receive": + await test_stress_queue_send_and_receive(args) + elif args.method == "send_pull_receive": + await test_stress_queue_send_and_pull_receive(args) + elif args.method == "send_receive_batch": + await test_stress_queue_batch_send_and_receive(args) + elif args.method == "send_receive_slow": + await 
test_stress_queue_slow_send_and_receive(args) + elif args.method == "receive_delete": + await test_stress_queue_receive_and_delete(args) + elif args.method == "unsettled_message": + await test_stress_queue_unsettled_messages(args) + elif args.method == "large_batch": + await test_stress_queue_receive_large_batch_size(args) + elif args.method == "pull_receive_timeout": + await test_stress_queue_pull_receive_timeout(args) + elif args.method == "long_renew": + await test_stress_queue_long_renew_send_and_receive(args) + elif args.method == "long_renew_session": + await test_stress_queue_long_renew_session_send_and_receive(args) + elif args.method == "queue_peek": + await test_stress_queue_peek_messages(args) + elif args.method == "queue_close_reopen": + await test_stress_queue_close_and_reopen(args) + elif args.method == "dropped_messages": + await test_stress_queue_check_for_dropped_messages(args) + else: + await test_stress_queue_send_and_receive(args) + await test_stress_queue_send_and_pull_receive(args) + await test_stress_queue_batch_send_and_receive(args) + await test_stress_queue_slow_send_and_receive(args) + await test_stress_queue_receive_and_delete(args) + await test_stress_queue_unsettled_messages(args) + await test_stress_queue_receive_large_batch_size(args) + await test_stress_queue_pull_receive_timeout(args) + await test_stress_queue_long_renew_send_and_receive(args) + await test_stress_queue_long_renew_session_send_and_receive(args) + await test_stress_queue_peek_messages(args) + await test_stress_queue_close_and_reopen(args) + await test_stress_queue_check_for_dropped_messages(args) + + +if __name__ == '__main__': + load_dotenv(dotenv_path=ENV_FILE, override=True) + parser = ArgumentParser() + parser.add_argument("--conn_str", help="ServiceBus connection string", + default=os.environ.get('SERVICE_BUS_CONNECTION_STR')) + parser.add_argument("--queue_name", help="The queue name.", default='testQueue') + parser.add_argument("--method", type=str) + 
parser.add_argument("--duration", type=int, default=259200) + parser.add_argument("--logging-enable", action="store_true") + parser.add_argument("--print_console", action="store_true") + + parser.add_argument("--send-batch-size", type=int, default=100) + parser.add_argument("--message-size", type=int, default=100) + + parser.add_argument("--receive-type", type=str, default="pull") + parser.add_argument("--max_wait_time", type=int, default=10) + parser.add_argument("--max_message_count", type=int, default=1) + parser.add_argument("--uamqp_mode", action="store_true") + parser.add_argument("--transport", action="store_true") + parser.add_argument("--debug_level", help="Logging level; one of Info, Debug, Warning, Error or Critical", type=str, default="Error") + args, _ = parser.parse_known_args() + + if args.transport: + TRANSPORT_TYPE = TransportType.AmqpOverWebsocket + else: + TRANSPORT_TYPE = TransportType.Amqp + + SERVICE_BUS_CONNECTION_STR = args.conn_str + SERVICEBUS_QUEUE_NAME = args.queue_name + LOGGING_ENABLE = args.logging_enable + LOGGING_LEVEL = getattr(logging, args.debug_level.upper(), logging.ERROR) + + sb_admin_client = ServiceBusAdministrationClient.from_connection_string(SERVICE_BUS_CONNECTION_STR) + loop = asyncio.get_event_loop() + loop.run_until_complete(run(args)) + + diff --git a/sdk/servicebus/azure-servicebus/stress/stress-test-resources.bicep b/sdk/servicebus/azure-servicebus/stress/stress-test-resources.bicep index 0cb2d7a3e3353..2bc46e3096994 100644 --- a/sdk/servicebus/azure-servicebus/stress/stress-test-resources.bicep +++ b/sdk/servicebus/azure-servicebus/stress/stress-test-resources.bicep @@ -10,6 +10,8 @@ var authorizationRuleName_var = '${baseName}/RootManageSharedAccessKey' var authorizationRuleNameNoManage_var = '${baseName}/NoManage' var serviceBusDataOwnerRoleId = '/subscriptions/${subscription().subscriptionId}/providers/Microsoft.Authorization/roleDefinitions/090c5cfd-751d-490a-894a-3ce6f1109419' +var sbPremiumName = 
'sb-premium-${baseName}' + resource servicebus 'Microsoft.ServiceBus/namespaces@2018-01-01-preview' = { name: baseName location: location @@ -22,6 +24,16 @@ resource servicebus 'Microsoft.ServiceBus/namespaces@2018-01-01-preview' = { } } +resource servicebusPremium 'Microsoft.ServiceBus/namespaces@2018-01-01-preview' = { + name: sbPremiumName + location: location + sku: { + name: 'Premium' + tier: 'Premium' + } +} + + resource authorizationRuleName 'Microsoft.ServiceBus/namespaces/AuthorizationRules@2015-08-01' = { name: authorizationRuleName_var location: location @@ -82,26 +94,27 @@ resource testQueue 'Microsoft.ServiceBus/namespaces/queues@2017-04-01' = { } } -//resource testQueueWithSessions 'Microsoft.ServiceBus/namespaces/queues@2017-04-01' = { -// parent: servicebus -// name: 'testQueueWithSessions' -// properties: { -// lockDuration: 'PT5M' -// maxSizeInMegabytes: 1024 -// requiresDuplicateDetection: false -// requiresSession: true -// defaultMessageTimeToLive: 'P10675199DT2H48M5.4775807S' -// deadLetteringOnMessageExpiration: false -// duplicateDetectionHistoryTimeWindow: 'PT10M' -// maxDeliveryCount: 10 -// autoDeleteOnIdle: 'P10675199DT2H48M5.4775807S' -// enablePartitioning: false -// enableExpress: false -// } -//} +resource testQueueWithSessions 'Microsoft.ServiceBus/namespaces/queues@2017-04-01' = { + parent: servicebus + name: 'testQueueWithSessions' + properties: { + lockDuration: 'PT5M' + maxSizeInMegabytes: 1024 + requiresDuplicateDetection: false + requiresSession: true + defaultMessageTimeToLive: 'P10675199DT2H48M5.4775807S' + deadLetteringOnMessageExpiration: false + duplicateDetectionHistoryTimeWindow: 'PT10M' + maxDeliveryCount: 10 + autoDeleteOnIdle: 'P10675199DT2H48M5.4775807S' + enablePartitioning: false + enableExpress: false + } +} -output SERVICE_BUS_CONNECTION_STR string = listKeys(resourceId('Microsoft.ServiceBus/namespaces/authorizationRules', baseName, 'RootManageSharedAccessKey'), apiVersion).primaryConnectionString -output 
SERVICE_BUS_QUEUE_NAME string = 'testQueue' -//output QUEUE_NAME_WITH_SESSIONS string = 'testQueueWithSessions' -//output SERVICE_BUS_CONNECTION_STRING_NO_MANAGE string = listKeys(resourceId('Microsoft.ServiceBus/namespaces/authorizationRules', baseName, 'NoManage'), apiVersion).primaryConnectionString -//output SERVICE_BUS_ENDPOINT string = replace(servicebus.properties.serviceBusEndpoint, ':443/', '') +output SERVICEBUS_CONNECTION_STRING string = listKeys(resourceId('Microsoft.ServiceBus/namespaces/authorizationRules', baseName, 'RootManageSharedAccessKey'), apiVersion).primaryConnectionString +output SERVICEBUS_CONNECTION_STRING_NO_MANAGE string = listKeys(resourceId('Microsoft.ServiceBus/namespaces/authorizationRules', baseName, 'NoManage'), apiVersion).primaryConnectionString +output SERVICEBUS_CONNECTION_STRING_PREMIUM string = listKeys(resourceId('Microsoft.ServiceBus/namespaces/authorizationRules', sbPremiumName, 'RootManageSharedAccessKey'), apiVersion).primaryConnectionString +output SERVICEBUS_ENDPOINT string = replace(replace(servicebus.properties.serviceBusEndpoint, ':443/', ''), 'https://', '') +output QUEUE_NAME string = 'testQueue' +output QUEUE_NAME_WITH_SESSIONS string = 'testQueueWithSessions' diff --git a/sdk/servicebus/azure-servicebus/stress/templates/network_loss.yaml b/sdk/servicebus/azure-servicebus/stress/templates/network_loss.yaml deleted file mode 100644 index a1e64b80ab9c6..0000000000000 --- a/sdk/servicebus/azure-servicebus/stress/templates/network_loss.yaml +++ /dev/null @@ -1,25 +0,0 @@ -{{- include "stress-test-addons.chaos-wrapper.tpl" (list . 
"stress.python-sb-network") -}} -{{- define "stress.python-sb-network" -}} -apiVersion: chaos-mesh.org/v1alpha1 -kind: NetworkChaos -spec: - scheduler: - cron: '@every 30s' - duration: '10s' - action: loss - direction: to - externalTargets: - - '{{ .Stress.ResourceGroupName }}.servicebus.windows.net' - mode: one - selector: - labelSelectors: - testInstance: "servicebus-{{ .Release.Name }}-{{ .Release.Revision }}" - chaos: 'true' - namespaces: - - {{ .Release.Namespace }} - podPhaseSelectors: - - 'Running' - loss: - loss: '100' - correlation: '100' -{{- end -}} \ No newline at end of file diff --git a/sdk/servicebus/azure-servicebus/stress/templates/testjob.yaml b/sdk/servicebus/azure-servicebus/stress/templates/testjob.yaml index 8bc542b855d0a..9d46259360de9 100644 --- a/sdk/servicebus/azure-servicebus/stress/templates/testjob.yaml +++ b/sdk/servicebus/azure-servicebus/stress/templates/testjob.yaml @@ -2,9 +2,9 @@ {{- define "stress.python-sb-stress" -}} metadata: labels: - testName: "deploy-python-sb-stress" - testInstance: "servicebus-{{ .Release.Name }}-{{ .Release.Revision }}" - chaos: "true" + testName: "py-sb-stress" + testInstance: "sb-{{ .Release.Name }}-{{ .Release.Revision }}" + chaos: "{{ default false .Stress.chaos }}" spec: containers: - name: python-sb-stress @@ -15,9 +15,63 @@ spec: memory: "2000Mi" cpu: "1" - {{ if eq .Stress.testTarget "sbStress" }} - command: ['bash', '-c', 'python3 test_stress_queues.py'] + {{ if eq .Stress.testTarget "aqueuew" }} + command: ['bash', '-c', 'mkdir -p $DEBUG_SHARE && cat > $DEBUG_SHARE/output && python test_stress_queues_async.py --method send_receive --duration 300000 --logging-enable --transport'] {{- end -}} - + + {{ if eq .Stress.testTarget "queuew" }} + command: ['bash', '-c', 'mkdir -p $DEBUG_SHARE && cat > $DEBUG_SHARE/output && python test_stress_queues.py --method send_receive --duration 300000 --logging-enable --transport'] + {{- end -}} + + {{ if eq .Stress.testTarget "aqueuepullw" }} + command: ['bash', '-c', 
'mkdir -p $DEBUG_SHARE && cat > $DEBUG_SHARE/output.bin && python test_stress_queues_async.py --method send_pull_receive --duration 300000 --logging-enable --transport'] + {{- end -}} + + {{ if eq .Stress.testTarget "queuepullw" }} + command: ['bash', '-c', 'mkdir -p $DEBUG_SHARE && cat > $DEBUG_SHARE/output.bin && python test_stress_queues.py --method send_pull_receive --duration 300000 --logging-enable --transport'] + {{- end -}} + + {{ if eq .Stress.testTarget "abatchw" }} + command: ['bash', '-c', 'mkdir -p $DEBUG_SHARE && cat > $DEBUG_SHARE/output.bin && python test_stress_queues_async.py --method send_receive_batch --duration 300000 --logging-enable --transport'] + {{- end -}} + + {{ if eq .Stress.testTarget "batchw" }} + command: ['bash', '-c', 'mkdir -p $DEBUG_SHARE && cat > $DEBUG_SHARE/output.bin && python test_stress_queues.py --method send_receive_batch --duration 300000 --logging-enable --transport'] + {{- end -}} + + {{ if eq .Stress.testTarget "aqueue" }} + command: ['bash', '-c', 'mkdir -p $DEBUG_SHARE && cat > $DEBUG_SHARE/output.bin && python test_stress_queues_async.py --method send_receive --duration 300000 --logging-enable'] + {{- end -}} + + {{ if eq .Stress.testTarget "queue" }} + command: ['bash', '-c', 'mkdir -p $DEBUG_SHARE && cat > $DEBUG_SHARE/output.bin && python test_stress_queues.py --method send_receive --duration 300000 --logging-enable'] + {{- end -}} + + {{ if eq .Stress.testTarget "aqueuepull" }} + command: ['bash', '-c', 'mkdir -p $DEBUG_SHARE && cat > $DEBUG_SHARE/output.bin && python test_stress_queues_async.py --method send_pull_receive --duration 300000 --logging-enable'] + {{- end -}} + + {{ if eq .Stress.testTarget "queuepull" }} + command: ['bash', '-c', 'mkdir -p $DEBUG_SHARE && cat > $DEBUG_SHARE/output.bin && python test_stress_queues.py --method send_pull_receive --duration 300000 --logging-enable --output $DEBUG_SHARE/output.bin'] + {{- end -}} + + {{ if eq .Stress.testTarget "abatch" }} + command: ['bash', '-c', 
'mkdir -p $DEBUG_SHARE && cat > $DEBUG_SHARE/output.bin && python test_stress_queues_async.py --method send_receive_batch --duration 300000 --logging-enable --output $DEBUG_SHARE/output.bin'] + {{- end -}} + + {{ if eq .Stress.testTarget "batch" }} + command: ['bash', '-c', 'mkdir -p $DEBUG_SHARE && cat > $DEBUG_SHARE/output.bin && python test_stress_queues.py --method send_receive_batch --duration 300000 --logging-enable --output $DEBUG_SHARE/output.bin'] + {{- end -}} + + {{ if eq .Stress.testTarget "amemray" }} + command: ['bash', '-c', 'mkdir -p $DEBUG_SHARE && memray run --output $DEBUG_SHARE/sb_async_memray_output.bin test_stress_queues_async.py --method send_pull_receive --duration 300000 --logging-enable'] + {{- end -}} + + {{ if eq .Stress.testTarget "memray" }} + command: ['bash', '-c', 'mkdir -p $DEBUG_SHARE && memray run --output $DEBUG_SHARE/sb_memray_output.bin test_stress_queues.py --method send_pull_receive --duration 300000 --logging-enable'] + {{- end -}} + {{- include "stress-test-addons.container-env" . 
| nindent 6 }} {{- end -}} + + diff --git a/sdk/servicebus/azure-servicebus/tests/async_tests/test_queues_async.py b/sdk/servicebus/azure-servicebus/tests/async_tests/test_queues_async.py index d86ae08a4c4aa..d0da99f9156c9 100644 --- a/sdk/servicebus/azure-servicebus/tests/async_tests/test_queues_async.py +++ b/sdk/servicebus/azure-servicebus/tests/async_tests/test_queues_async.py @@ -5,6 +5,7 @@ #-------------------------------------------------------------------------- import asyncio +import json import logging import sys import os @@ -12,11 +13,19 @@ import pytest import time import uuid +import pickle from datetime import datetime, timedelta -import uamqp -import uamqp.errors -from uamqp import compat +try: + import uamqp + from azure.servicebus.aio._transport._uamqp_transport_async import UamqpTransportAsync +except ImportError: + uamqp = None + +try: + from azure.servicebus.aio._transport._pyamqp_transport_async import PyamqpTransportAsync +except: + PyamqpTransportAsync = None from azure.servicebus.aio import ( ServiceBusClient, AutoLockRenewer @@ -36,6 +45,9 @@ AmqpAnnotatedMessage, AmqpMessageProperties, ) +from azure.servicebus._pyamqp.message import Message +from azure.servicebus._pyamqp import error, management_operation +from azure.servicebus._pyamqp.aio import AMQPClientAsync, ReceiveClientAsync, _management_operation_async from azure.servicebus._common.constants import ServiceBusReceiveMode, ServiceBusSubQueue from azure.servicebus._common.utils import utc_now from azure.servicebus.management._models import DictMixin @@ -48,7 +60,7 @@ MessageSizeExceededError, OperationTimeoutError ) -from devtools_testutils import AzureMgmtTestCase, AzureTestCase +from devtools_testutils import AzureMgmtRecordedTestCase, AzureTestCase from servicebus_preparer import ( CachedServiceBusNamespacePreparer, CachedServiceBusQueuePreparer, @@ -57,11 +69,14 @@ ) from utilities import get_logger, print_message, sleep_until_expired from mocks_async import MockReceivedMessage, 
MockReceiver +from utilities import get_logger, print_message, sleep_until_expired, uamqp_transport as get_uamqp_transport, ArgPasserAsync + +uamqp_transport_params, uamqp_transport_ids = get_uamqp_transport() _logger = get_logger(logging.DEBUG) -class ServiceBusQueueAsyncTests(AzureMgmtTestCase): +class TestServiceBusQueueAsync(AzureMgmtRecordedTestCase): @pytest.mark.asyncio @pytest.mark.liveTest @@ -69,9 +84,11 @@ class ServiceBusQueueAsyncTests(AzureMgmtTestCase): @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True, lock_duration='PT10S') - async def test_async_queue_by_queue_client_conn_str_receive_handler_peeklock(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_async_queue_by_queue_client_conn_str_receive_handler_peeklock(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): async with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: sender = sb_client.get_queue_sender(servicebus_queue.name) async with sender: @@ -108,9 +125,6 @@ async def test_async_queue_by_queue_client_conn_str_receive_handler_peeklock(sel with pytest.raises(ValueError): await receiver.receive_messages(max_wait_time=0) - with pytest.raises(ValueError): - await receiver._get_streaming_message_iter(max_wait_time=0) - count = 0 async for message in receiver: print_message(_logger, message) @@ -129,15 +143,18 @@ async def test_async_queue_by_queue_client_conn_str_receive_handler_peeklock(sel with pytest.raises(ValueError): await 
receiver.peek_messages() + @pytest.mark.asyncio @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True, lock_duration='PT10S') - async def test_async_queue_by_queue_client_conn_str_receive_handler_release_messages(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_async_queue_by_queue_client_conn_str_receive_handler_release_messages(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): async with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: async def sub_test_releasing_messages(): # test releasing messages when prefetch is 1 and link credits are issue dynamically @@ -242,16 +259,26 @@ async def sub_test_non_releasing_messages(): receiver = sb_client.get_queue_receiver(servicebus_queue.name) sender = sb_client.get_queue_sender(servicebus_queue.name) - def _hack_disable_receive_context_message_received(self, message): - # pylint: disable=protected-access - self._handler._was_message_received = True - self._handler._received_messages.put(message) + if uamqp_transport: + def _hack_disable_receive_context_message_received(self, message): + # pylint: disable=protected-access + self._handler._was_message_received = True + self._handler._received_messages.put(message) + else: + def _hack_disable_receive_context_message_received(self, frame, message): + # pylint: disable=protected-access + self._handler._last_activity_timestamp = time.time() + 
self._handler._received_messages.put((frame, message)) async with sender, receiver: # send 5 msgs to queue first await sender.send_messages([ServiceBusMessage('test') for _ in range(5)]) - receiver._handler.message_handler.on_message_received = types.MethodType( - _hack_disable_receive_context_message_received, receiver) + if uamqp_transport: + receiver._handler.message_handler.on_message_received = types.MethodType( + _hack_disable_receive_context_message_received, receiver) + else: + receiver._handler._link._on_transfer = types.MethodType( + _hack_disable_receive_context_message_received, receiver) received_msgs = [] while len(received_msgs) < 5: # issue 10 link credits, client should consume 5 msgs from the service @@ -290,23 +317,25 @@ def _hack_disable_receive_context_message_received(self, message): await sub_test_releasing_messages_iterator() await sub_test_non_releasing_messages() + @pytest.mark.asyncio @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) - async def test_async_queue_by_queue_client_send_multiple_messages(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_async_queue_by_queue_client_send_multiple_messages(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): async with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: sender = sb_client.get_queue_sender(servicebus_queue.name) - messages = [] - for i in range(10): - message = 
ServiceBusMessage("Handler message no. {}".format(i)) - messages.append(message) - await sender.send_messages(messages) - assert sender._handler._msg_timeout == 0 - await sender.close() + async with sender: + messages = [] + for i in range(10): + message = ServiceBusMessage("Handler message no. {}".format(i)) + messages.append(message) + await sender.send_messages(messages) with pytest.raises(ValueError): async with sender: @@ -339,15 +368,72 @@ async def test_async_queue_by_queue_client_send_multiple_messages(self, serviceb with pytest.raises(ValueError): await receiver.peek_messages() + sender = sb_client.get_queue_sender(servicebus_queue.name) + receiver = sb_client.get_queue_receiver(servicebus_queue.name, max_wait_time=5) + async with sender, receiver: + # send previously unpicklable message + msg = { + "body":"W1tdLCB7ImlucHV0X2lkIjogNH0sIHsiY2FsbGJhY2tzIjogbnVsbCwgImVycmJhY2tzIjogbnVsbCwgImNoYWluIjogbnVsbCwgImNob3JkIjogbnVsbH1d", + "content-encoding":"utf-8", + "content-type":"application/json", + "headers":{ + "lang":"py", + "task":"tasks.example_task", + "id":"7c66557d-e4bc-437f-b021-b66dcc39dfdf", + "shadow":None, + "eta":"2021-10-07T02:30:23.764066+00:00", + "expires":None, + "group":None, + "group_index":None, + "retries":1, + "timelimit":[ + None, + None + ], + "root_id":"7c66557d-e4bc-437f-b021-b66dcc39dfdf", + "parent_id":"7c66557d-e4bc-437f-b021-b66dcc39dfdf", + "argsrepr":"()", + "kwargsrepr":"{'input_id': 4}", + "origin":"gen36@94713e01a9c0", + "ignore_result":1, + "x_correlator":"44a1978d-c869-4173-afe4-da741f0edfb9" + }, + "properties":{ + "correlation_id":"7c66557d-e4bc-437f-b021-b66dcc39dfdf", + "reply_to":"7b9a3672-2fed-3e9b-8bfd-23ae2397d9ad", + "origin":"gen68@c33d4eef123a", + "delivery_mode":2, + "delivery_info":{ + "exchange":"", + "routing_key":"celery_task_queue" + }, + "priority":0, + "body_encoding":"base64", + "delivery_tag":"dc83ddb6-8cdc-4413-b88a-06c56cbde90d" + } + } + await 
sender.send_messages(ServiceBusMessage(json.dumps(msg))) + messages = await receiver.receive_messages(max_wait_time=10, max_message_count=1) + if not uamqp_transport: + pickled = pickle.loads(pickle.dumps(messages[0])) + assert json.loads(str(pickled)) == json.loads(str(messages[0])) + await receiver.complete_message(pickled) + else: + with pytest.raises(TypeError): + pickled = pickle.loads(pickle.dumps(messages[0])) + await receiver.complete_message(messages[0]) + @pytest.mark.asyncio @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer() @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) - async def test_github_issue_7079_async(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_github_issue_7079_async(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): async with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: async with sb_client.get_queue_sender(servicebus_queue.name) as sender: for i in range(5): @@ -359,6 +445,7 @@ async def test_github_issue_7079_async(self, servicebus_namespace_connection_str _logger.debug(message) count += 1 assert count == 5 + @pytest.mark.asyncio @pytest.mark.liveTest @@ -366,9 +453,11 @@ async def test_github_issue_7079_async(self, servicebus_namespace_connection_str @CachedServiceBusResourceGroupPreparer() @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @CachedServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) - async def 
test_github_issue_6178_async(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_github_issue_6178_async(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): async with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: async with sb_client.get_queue_sender(servicebus_queue.name) as sender: for i in range(3): @@ -382,15 +471,18 @@ async def test_github_issue_6178_async(self, servicebus_namespace_connection_str await receiver.complete_message(message) await asyncio.sleep(40) + @pytest.mark.asyncio @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True, lock_duration='PT10S') - async def test_async_queue_by_queue_client_conn_str_receive_handler_receiveanddelete(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_async_queue_by_queue_client_conn_str_receive_handler_receiveanddelete(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): async with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: async with sb_client.get_queue_sender(servicebus_queue.name) as sender: for i in range(10): @@ 
-417,15 +509,18 @@ async def test_async_queue_by_queue_client_conn_str_receive_handler_receiveandde messages.append(message) assert len(messages) == 0 + @pytest.mark.asyncio @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) - async def test_async_queue_by_queue_client_conn_str_receive_handler_with_stop(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_async_queue_by_queue_client_conn_str_receive_handler_with_stop(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): async with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: async with sb_client.get_queue_sender(servicebus_queue.name) as sender: for i in range(10): @@ -453,15 +548,18 @@ async def test_async_queue_by_queue_client_conn_str_receive_handler_with_stop(se assert not receiver._running assert len(messages) == 6 + @pytest.mark.asyncio @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) - async def test_async_queue_by_servicebus_client_iter_messages_simple(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def 
test_async_queue_by_servicebus_client_iter_messages_simple(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): async with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: async with sb_client.get_queue_receiver(servicebus_queue.name, max_wait_time=10, receive_mode=ServiceBusReceiveMode.PEEK_LOCK) as receiver: @@ -485,15 +583,18 @@ async def test_async_queue_by_servicebus_client_iter_messages_simple(self, servi assert count == 10 + @pytest.mark.asyncio @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) - async def test_async_queue_by_servicebus_conn_str_client_iter_messages_with_abandon(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_async_queue_by_servicebus_conn_str_client_iter_messages_with_abandon(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): async with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: async with sb_client.get_queue_receiver(servicebus_queue.name, max_wait_time=10, receive_mode=ServiceBusReceiveMode.PEEK_LOCK) as receiver: @@ -522,15 +623,18 @@ async def test_async_queue_by_servicebus_conn_str_client_iter_messages_with_aban count += 1 assert count == 0 + @pytest.mark.asyncio @pytest.mark.liveTest 
@pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) - async def test_async_queue_by_servicebus_client_iter_messages_with_defer(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_async_queue_by_servicebus_client_iter_messages_with_defer(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): async with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: deferred_messages = [] async with sb_client.get_queue_receiver(servicebus_queue.name, max_wait_time=10, receive_mode=ServiceBusReceiveMode.PEEK_LOCK) as receiver: @@ -556,15 +660,18 @@ async def test_async_queue_by_servicebus_client_iter_messages_with_defer(self, s count += 1 assert count == 0 + @pytest.mark.asyncio @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) - async def test_async_queue_by_servicebus_client_iter_messages_with_retrieve_deferred_client(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_async_queue_by_servicebus_client_iter_messages_with_retrieve_deferred_client(self, uamqp_transport, *, servicebus_namespace_connection_string=None, 
servicebus_queue=None, **kwargs): async with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: deferred_messages = [] async with sb_client.get_queue_receiver(servicebus_queue.name, max_wait_time=10, receive_mode=ServiceBusReceiveMode.PEEK_LOCK) as receiver: @@ -592,15 +699,18 @@ async def test_async_queue_by_servicebus_client_iter_messages_with_retrieve_defe with pytest.raises(ServiceBusError): await receiver.receive_deferred_messages(deferred_messages) + @pytest.mark.asyncio @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) - async def test_async_queue_by_servicebus_client_iter_messages_with_retrieve_deferred_receiver_complete(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_async_queue_by_servicebus_client_iter_messages_with_retrieve_deferred_receiver_complete(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): async with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: deferred_messages = [] async with sb_client.get_queue_sender(servicebus_queue.name) as sender: @@ -629,15 +739,18 @@ async def test_async_queue_by_servicebus_client_iter_messages_with_retrieve_defe await receiver.renew_message_lock(message) await receiver.complete_message(message) + @pytest.mark.asyncio 
@pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) - async def test_async_queue_by_servicebus_client_iter_messages_with_retrieve_deferred_receiver_deadletter(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_async_queue_by_servicebus_client_iter_messages_with_retrieve_deferred_receiver_deadletter(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): async with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: deferred_messages = [] async with sb_client.get_queue_sender(servicebus_queue.name) as sender: @@ -675,15 +788,18 @@ async def test_async_queue_by_servicebus_client_iter_messages_with_retrieve_defe await receiver.complete_message(message) assert count == 10 + @pytest.mark.asyncio @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) - async def test_async_queue_by_servicebus_client_iter_messages_with_retrieve_deferred_receiver_deletemode(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def 
test_async_queue_by_servicebus_client_iter_messages_with_retrieve_deferred_receiver_deletemode(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): async with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: deferred_messages = [] async with sb_client.get_queue_sender(servicebus_queue.name) as sender: @@ -709,15 +825,18 @@ async def test_async_queue_by_servicebus_client_iter_messages_with_retrieve_defe with pytest.raises(ServiceBusError): deferred = await receiver.receive_deferred_messages(deferred_messages) + @pytest.mark.asyncio @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) - async def test_async_queue_by_servicebus_client_iter_messages_with_retrieve_deferred_not_found(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_async_queue_by_servicebus_client_iter_messages_with_retrieve_deferred_not_found(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): async with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: deferred_messages = [] async with sb_client.get_queue_receiver(servicebus_queue.name, max_wait_time=10, receive_mode=ServiceBusReceiveMode.PEEK_LOCK) as receiver: @@ -743,15 +862,18 @@ async def 
test_async_queue_by_servicebus_client_iter_messages_with_retrieve_defe with pytest.raises(ServiceBusError): deferred = await receiver.receive_deferred_messages([5, 6, 7]) + @pytest.mark.asyncio @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) - async def test_async_queue_by_servicebus_client_receive_batch_with_deadletter(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_async_queue_by_servicebus_client_receive_batch_with_deadletter(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): async with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: async with sb_client.get_queue_receiver(servicebus_queue.name, max_wait_time=10, receive_mode=ServiceBusReceiveMode.PEEK_LOCK, prefetch_count=10) as receiver: @@ -794,15 +916,18 @@ async def test_async_queue_by_servicebus_client_receive_batch_with_deadletter(se assert message.application_properties[b'DeadLetterErrorDescription'] == b'Testing description' assert count == 10 + @pytest.mark.asyncio @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) - async def test_async_queue_by_servicebus_client_receive_batch_with_retrieve_deadletter(self, servicebus_namespace_connection_string, 
servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_async_queue_by_servicebus_client_receive_batch_with_retrieve_deadletter(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): async with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: async with sb_client.get_queue_receiver(servicebus_queue.name, max_wait_time=5, receive_mode=ServiceBusReceiveMode.PEEK_LOCK, prefetch_count=10) as receiver: @@ -845,25 +970,29 @@ async def test_async_queue_by_servicebus_client_receive_batch_with_retrieve_dead @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @CachedServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) - async def test_async_queue_by_servicebus_client_session_fail(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_async_queue_by_servicebus_client_session_fail(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): async with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: with pytest.raises(ServiceBusError): await sb_client.get_queue_receiver(servicebus_queue.name, session_id="test")._open_with_retry() async with sb_client.get_queue_sender(servicebus_queue.name) as sender: await sender.send_messages(ServiceBusMessage("test 
session sender", session_id="test")) - + @pytest.mark.asyncio @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) - async def test_async_queue_by_servicebus_client_browse_messages_client(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_async_queue_by_servicebus_client_browse_messages_client(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): async with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: async with sb_client.get_queue_sender(servicebus_queue.name) as sender: for i in range(5): @@ -885,9 +1014,11 @@ async def test_async_queue_by_servicebus_client_browse_messages_client(self, ser @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) - async def test_async_queue_by_servicebus_client_browse_messages_with_receiver(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_async_queue_by_servicebus_client_browse_messages_with_receiver(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): async with ServiceBusClient.from_connection_string( - 
servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: async with sb_client.get_queue_receiver(servicebus_queue.name, max_wait_time=5, receive_mode=ServiceBusReceiveMode.PEEK_LOCK) as receiver: async with sb_client.get_queue_sender(servicebus_queue.name) as sender: @@ -909,9 +1040,11 @@ async def test_async_queue_by_servicebus_client_browse_messages_with_receiver(se @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) - async def test_async_queue_by_servicebus_client_browse_empty_messages(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_async_queue_by_servicebus_client_browse_empty_messages(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): async with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: async with sb_client.get_queue_receiver(servicebus_queue.name, max_wait_time=5, receive_mode=ServiceBusReceiveMode.PEEK_LOCK, prefetch_count=10) as receiver: messages = await receiver.peek_messages(10) @@ -923,9 +1056,11 @@ async def test_async_queue_by_servicebus_client_browse_empty_messages(self, serv @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) - async def 
test_async_queue_by_servicebus_client_renew_message_locks(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_async_queue_by_servicebus_client_renew_message_locks(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): async with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: messages = [] locks = 3 @@ -956,15 +1091,18 @@ async def test_async_queue_by_servicebus_client_renew_message_locks(self, servic with pytest.raises(ServiceBusError): await receiver.complete_message(messages[2]) + @pytest.mark.asyncio @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True, lock_duration='PT5S') - async def test_async_queue_by_queue_client_conn_str_receive_handler_with_autolockrenew(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_async_queue_by_queue_client_conn_str_receive_handler_with_autolockrenew(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): async with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: async with sb_client.get_queue_sender(servicebus_queue.name) as sender: for i in 
range(10): @@ -1011,15 +1149,18 @@ async def test_async_queue_by_queue_client_conn_str_receive_handler_with_autoloc await renewer.close() assert len(messages) == 11 + @pytest.mark.asyncio @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True, lock_duration='PT5S') - async def test_async_queue_by_queue_client_conn_str_receive_handler_with_auto_autolockrenew(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_async_queue_by_queue_client_conn_str_receive_handler_with_auto_autolockrenew(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): async with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: async with sb_client.get_queue_sender(servicebus_queue.name) as sender: # The 10 iterations are "important" because they give time for the timed-out message to be received again.
@@ -1071,9 +1212,11 @@ async def test_async_queue_by_queue_client_conn_str_receive_handler_with_auto_au @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) - async def test_async_queue_by_servicebus_client_fail_send_messages(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_async_queue_by_servicebus_client_fail_send_messages(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): async with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: too_large = "A" * 1024 * 256 @@ -1085,15 +1228,18 @@ async def test_async_queue_by_servicebus_client_fail_send_messages(self, service with pytest.raises(MessageSizeExceededError): await sender.send_messages([ServiceBusMessage(half_too_large), ServiceBusMessage(half_too_large)]) + @pytest.mark.asyncio @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) - async def test_async_queue_message_time_to_live(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_async_queue_message_time_to_live(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): async with 
ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: async with sb_client.get_queue_sender(servicebus_queue.name) as sender: content = str(uuid.uuid4()) @@ -1118,15 +1264,18 @@ async def test_async_queue_message_time_to_live(self, servicebus_namespace_conne count += 1 assert count == 1 + @pytest.mark.asyncio @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', requires_duplicate_detection=True, dead_lettering_on_message_expiration=True) - async def test_async_queue_message_duplicate_detection(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_async_queue_message_duplicate_detection(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): async with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: message_id = uuid.uuid4() @@ -1151,9 +1300,11 @@ async def test_async_queue_message_duplicate_detection(self, servicebus_namespac @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) - async def test_async_queue_message_connection_closed(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", 
uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_async_queue_message_connection_closed(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): async with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: async with sb_client.get_queue_sender(servicebus_queue.name) as sender: content = str(uuid.uuid4()) @@ -1173,9 +1324,11 @@ async def test_async_queue_message_connection_closed(self, servicebus_namespace_ @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) - async def test_async_queue_message_expiry(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_async_queue_message_expiry(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): async with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: async with sb_client.get_queue_sender(servicebus_queue.name) as sender: content = str(uuid.uuid4()) @@ -1205,9 +1358,11 @@ async def test_async_queue_message_expiry(self, servicebus_namespace_connection_ @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) - async def 
test_async_queue_message_lock_renew(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_async_queue_message_lock_renew(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): async with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: async with sb_client.get_queue_sender(servicebus_queue.name) as sender: content = str(uuid.uuid4()) @@ -1235,9 +1390,19 @@ async def test_async_queue_message_lock_renew(self, servicebus_namespace_connect @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True, lock_duration='PT10S') - async def test_async_queue_message_receive_and_delete(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_async_queue_message_receive_and_delete(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): + if uamqp: + transport_type = uamqp.constants.TransportType.Amqp + else: + transport_type = TransportType.Amqp async with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, + transport_type=transport_type, + logging_enable=False, + uamqp_transport=uamqp_transport + ) as sb_client: async with sb_client.get_queue_sender(servicebus_queue.name) as sender: message = ServiceBusMessage("Receive and delete 
test") @@ -1272,9 +1437,11 @@ async def test_async_queue_message_receive_and_delete(self, servicebus_namespace @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) - async def test_async_queue_message_batch(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_async_queue_message_batch(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): async with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: def message_content(): for i in range(5): @@ -1292,6 +1459,7 @@ def message_content(): yield message async with sb_client.get_queue_sender(servicebus_queue.name) as sender: + # sending manually created message batch (with default pyamqp) should work for both uamqp/pyamqp message = ServiceBusMessageBatch() for each in message_content(): message.add_message(each) @@ -1329,9 +1497,11 @@ def message_content(): @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) - async def test_async_queue_schedule_message(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_async_queue_schedule_message(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): async 
with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: scheduled_enqueue_time = (utc_now() + timedelta(minutes=2)).replace(microsecond=0) async with sb_client.get_queue_receiver(servicebus_queue.name) as receiver: @@ -1358,18 +1528,21 @@ async def test_async_queue_schedule_message(self, servicebus_namespace_connectio else: raise Exception("Failed to receive scheduled message.") + @pytest.mark.asyncio @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) - async def test_async_queue_schedule_multiple_messages(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_async_queue_schedule_multiple_messages(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): async with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: scheduled_enqueue_time = (utc_now() + timedelta(minutes=2)).replace(microsecond=0) messages = [] - receiver = sb_client.get_queue_receiver(servicebus_queue.name, prefetch_count=20) + receiver = sb_client.get_queue_receiver(servicebus_queue.name, prefetch_count=20, max_wait_time=5) sender = sb_client.get_queue_sender(servicebus_queue.name) async with sender, receiver: content = str(uuid.uuid4()) @@ -1383,7 +1556,7 @@ async def test_async_queue_schedule_multiple_messages(self, 
servicebus_namespace await sender.send_messages([message_a, message_b]) received_messages = [] - async for message in receiver._get_streaming_message_iter(max_wait_time=5): + async for message in receiver: received_messages.append(message) await receiver.complete_message(message) @@ -1404,6 +1577,11 @@ async def test_async_queue_schedule_multiple_messages(self, servicebus_namespace finally: for message in messages: await receiver.complete_message(message) + if not uamqp_transport: + pickled = pickle.loads(pickle.dumps(messages[0])) + assert pickled.message_id == messages[0].message_id + assert pickled.scheduled_enqueue_time_utc == messages[0].scheduled_enqueue_time_utc + assert pickled.scheduled_enqueue_time_utc <= pickled.enqueued_time_utc.replace(microsecond=0) else: raise Exception("Failed to receive scheduled message.") @@ -1413,9 +1591,11 @@ async def test_async_queue_schedule_multiple_messages(self, servicebus_namespace @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) - async def test_async_queue_cancel_scheduled_messages(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_async_queue_cancel_scheduled_messages(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): async with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: enqueue_time = (utc_now() + timedelta(minutes=2)).replace(microsecond=0) async with sb_client.get_queue_receiver(servicebus_queue.name) as receiver: @@ -1436,11 +1616,13 @@ async 
def test_async_queue_cancel_scheduled_messages(self, servicebus_namespace_ @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) - async def test_queue_message_amqp_over_websocket(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_queue_message_amqp_over_websocket(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): async with ServiceBusClient.from_connection_string( servicebus_namespace_connection_string, transport_type=TransportType.AmqpOverWebsocket, - logging_enable=False) as sb_client: + logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: async with sb_client.get_queue_sender(servicebus_queue.name) as sender: assert sender._config.transport_type == TransportType.AmqpOverWebsocket @@ -1452,7 +1634,8 @@ async def test_queue_message_amqp_over_websocket(self, servicebus_namespace_conn messages = await receiver.receive_messages(max_wait_time=5) assert len(messages) == 1 - async def test_queue_message_http_proxy_setting(self): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + def test_queue_message_http_proxy_setting(self, uamqp_transport): mock_conn_str = "Endpoint=sb://mock.servicebus.windows.net/;SharedAccessKeyName=mock;SharedAccessKey=mock" http_proxy = { 'proxy_hostname': '127.0.0.1', @@ -1461,7 +1644,7 @@ async def test_queue_message_http_proxy_setting(self): 'password': '123456' } - sb_client = ServiceBusClient.from_connection_string(mock_conn_str, http_proxy=http_proxy) + sb_client = ServiceBusClient.from_connection_string(mock_conn_str, http_proxy=http_proxy, uamqp_transport=uamqp_transport) assert 
sb_client._config.http_proxy == http_proxy assert sb_client._config.transport_type == TransportType.AmqpOverWebsocket @@ -1479,10 +1662,12 @@ async def test_queue_message_http_proxy_setting(self): @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) - async def test_queue_message_settle_through_mgmt_link_due_to_broken_receiver_link(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_queue_message_settle_through_mgmt_link_due_to_broken_receiver_link(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): async with ServiceBusClient.from_connection_string( servicebus_namespace_connection_string, - logging_enable=False) as sb_client: + logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: async with sb_client.get_queue_sender(servicebus_queue.name) as sender: message = ServiceBusMessage("Test") @@ -1490,11 +1675,15 @@ async def test_queue_message_settle_through_mgmt_link_due_to_broken_receiver_lin async with sb_client.get_queue_receiver(servicebus_queue.name) as receiver: messages = await receiver.receive_messages(max_wait_time=5) - await receiver._handler.message_handler.destroy_async() # destroy the underlying receiver link + # destroy the underlying receiver link + if uamqp_transport: + await receiver._handler.message_handler.destroy_async() + else: + await receiver._handler._link.detach() assert len(messages) == 1 await receiver.complete_message(messages[0]) - @pytest.mark.asyncio + @AzureTestCase.await_prepared_test async def test_async_queue_mock_auto_lock_renew_callback(self): # A warning to future devs: If the renew period override heuristic in registration # ever 
changes, it may break this (since it adjusts renew period if it is not short enough) @@ -1577,7 +1766,7 @@ async def callback_mock(renewable, error): assert not results assert not errors - @pytest.mark.asyncio + @AzureTestCase.await_prepared_test async def test_async_queue_mock_no_reusing_auto_lock_renew(self): auto_lock_renew = AutoLockRenewer() auto_lock_renew._renew_period = 1 @@ -1615,10 +1804,12 @@ async def test_async_queue_mock_no_reusing_auto_lock_renew(self): @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @CachedServiceBusQueuePreparer(name_prefix='servicebustest') - async def test_queue_receiver_invalid_autolockrenew_mode(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_queue_receiver_invalid_autolockrenew_mode(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): async with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: with pytest.raises(ValueError): async with sb_client.get_queue_receiver(servicebus_queue.name, receive_mode=ServiceBusReceiveMode.RECEIVE_AND_DELETE, @@ -1632,14 +1823,16 @@ async def test_queue_receiver_invalid_autolockrenew_mode(self, servicebus_namesp @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) - async def test_async_queue_receive_batch_without_setting_prefetch(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", 
uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_async_queue_receive_batch_without_setting_prefetch(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): async with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: def message_content(): for i in range(20): yield ServiceBusMessage( - body="ServiceBusMessage no. {}".format(i), + body=f"ServiceBusMessage no. {i}", subject='1st' ) @@ -1679,16 +1872,19 @@ def message_content(): # Network/server might be unstable, making flow control ineffective in the leading rounds of connection iteration assert receive_counter < 10 # Dynamic link credit issuing comes into effect + @pytest.mark.asyncio @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) - async def test_async_queue_receiver_alive_after_timeout(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_async_queue_receiver_alive_after_timeout(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): async with ServiceBusClient.from_connection_string( servicebus_namespace_connection_string, - logging_enable=False) as sb_client: + logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: async with sb_client.get_queue_sender(servicebus_queue.name) as sender: message = ServiceBusMessage("0") @@ -1698,11 +1894,11 @@ async def test_async_queue_receiver_alive_after_timeout(self,
servicebus_namespa messages = [] async with sb_client.get_queue_receiver(servicebus_queue.name, max_wait_time=10) as receiver: - async for message in receiver._get_streaming_message_iter(): + async for message in receiver: messages.append(message) break - async for message in receiver._get_streaming_message_iter(): + async for message in receiver: messages.append(message) for message in messages: @@ -1716,9 +1912,9 @@ async def test_async_queue_receiver_alive_after_timeout(self, servicebus_namespa message_3 = ServiceBusMessage("3") await sender.send_messages([message_2, message_3]) - async for message in receiver._get_streaming_message_iter(): + async for message in receiver: messages.append(message) - async for message in receiver._get_streaming_message_iter(): + async for message in receiver: messages.append(message) assert len(messages) == 4 @@ -1731,15 +1927,18 @@ async def test_async_queue_receiver_alive_after_timeout(self, servicebus_namespa messages = await receiver.receive_messages() assert not messages + @pytest.mark.asyncio @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True, lock_duration='PT5M') - async def test_queue_receive_keep_conn_alive_async(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_queue_receive_keep_conn_alive_async(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): async with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: 
sender = sb_client.get_queue_sender(servicebus_queue.name) receiver = sb_client.get_queue_receiver(servicebus_queue.name, max_wait_time=5) @@ -1765,63 +1964,23 @@ async def test_queue_receive_keep_conn_alive_async(self, servicebus_namespace_co assert len(messages) == 0 # make sure messages are removed from the queue assert receiver_handler == receiver._handler # make sure no reconnection happened - - @pytest.mark.asyncio - @pytest.mark.liveTest - @pytest.mark.live_test_only - @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') - @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') - @ServiceBusQueuePreparer(name_prefix='servicebustest') - async def test_async_queue_receiver_respects_max_wait_time_overrides(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): - async with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, - logging_enable=False) as sb_client: - - async with sb_client.get_queue_sender(servicebus_queue.name) as sender: - message = ServiceBusMessage("0") - await sender.send_messages(message) - - messages = [] - async with sb_client.get_queue_receiver(servicebus_queue.name, max_wait_time=5) as receiver: - - time_1 = receiver._handler._counter.get_current_ms() - async for message in receiver._get_streaming_message_iter(max_wait_time=10): - messages.append(message) - await receiver.complete_message(message) - - time_2 = receiver._handler._counter.get_current_ms() - async for message in receiver._get_streaming_message_iter(max_wait_time=1): - messages.append(message) - time_3 = receiver._handler._counter.get_current_ms() - assert timedelta(seconds=.5) < timedelta(milliseconds=(time_3 - time_2)) <= timedelta(seconds=2) - time_4 = receiver._handler._counter.get_current_ms() - assert timedelta(seconds=8) < timedelta(milliseconds=(time_4 - time_3)) <= timedelta(seconds=11) - - async for message in receiver._get_streaming_message_iter(max_wait_time=3): - messages.append(message) 
- time_5 = receiver._handler._counter.get_current_ms() - assert timedelta(seconds=1) < timedelta(milliseconds=(time_5 - time_4)) <= timedelta(seconds=4) - - async for message in receiver: - messages.append(message) - time_6 = receiver._handler._counter.get_current_ms() - assert timedelta(seconds=3) < timedelta(milliseconds=(time_6 - time_5)) <= timedelta(seconds=6) - - async for message in receiver._get_streaming_message_iter(): - messages.append(message) - time_7 = receiver._handler._counter.get_current_ms() - assert timedelta(seconds=3) < timedelta(milliseconds=(time_7 - time_6)) <= timedelta(seconds=6) - assert len(messages) == 1 - + @pytest.mark.asyncio @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest') - async def test_async_queue_send_twice(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_async_queue_send_twice(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): + if uamqp: + transport_type = uamqp.constants.TransportType.AmqpOverWebsocket + else: + transport_type = TransportType.AmqpOverWebsocket async with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, + transport_type=transport_type, uamqp_transport=uamqp_transport) as sb_client: async with sb_client.get_queue_sender(servicebus_queue.name) as sender: message = ServiceBusMessage("ServiceBusMessage") @@ -1840,12 +1999,19 @@ async def test_async_queue_send_twice(self, servicebus_namespace_connection_stri # then normal message resending await sender.send_messages(message) await 
sender.send_messages(message) - messages = [] - async with sb_client.get_queue_receiver(servicebus_queue.name, max_wait_time=20) as receiver: - async for message in receiver: - messages.append(message) - await receiver.complete_message(message) - assert len(messages) == 2 + expected_count = 2 + if not uamqp_transport: + # pyamqp re-send received pickled message + pickled_recvd = pickle.loads(pickle.dumps(messages[0])) + await sender.send_messages(pickled_recvd) + expected_count = 3 + messages = [] + async with sb_client.get_queue_receiver(servicebus_queue.name, max_wait_time=20) as receiver: + # pyamqp receiver should be picklable + async for message in receiver: + messages.append(message) + await receiver.complete_message(message) + assert len(messages) == expected_count @pytest.mark.asyncio @pytest.mark.liveTest @@ -1853,21 +2019,31 @@ async def test_async_queue_send_twice(self, servicebus_namespace_connection_stri @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @CachedServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) - async def test_async_queue_send_timeout(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): - async def _hack_amqp_sender_run_async(cls): - await asyncio.sleep(6) # sleep until timeout - await cls.message_handler.work_async() - cls._waiting_messages = 0 - cls._pending_messages = cls._filter_pending() - if cls._backoff and not cls._waiting_messages: - _logger.info("Client told to backoff - sleeping for %r seconds", cls._backoff) - await cls._connection.sleep_async(cls._backoff) - cls._backoff = 0 - await cls._connection.work_async() + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_async_queue_send_timeout(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): + 
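The pyamqp branch above pickles a received message and re-sends it. The round-trip it relies on can be sketched with a plain stand-in object (`ReceivedMessage` here is illustrative, not the SDK type — the point is that pyamqp-based messages are ordinary Python objects with no unpicklable C handles):

```python
import pickle

class ReceivedMessage:
    # Stand-in for a pyamqp received message: plain attributes, so it pickles.
    def __init__(self, body, message_id):
        self.body = body
        self.message_id = message_id

original = ReceivedMessage(b"payload", "id-1")
# A dumps/loads round-trip yields an equivalent message that can be re-sent.
restored = pickle.loads(pickle.dumps(original))
```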
async def _hack_amqp_sender_run_async(self, **kwargs): + time.sleep(6) # sleep until timeout + if uamqp_transport: + await self.message_handler.work_async() + self._waiting_messages = 0 + self._pending_messages = self._filter_pending() + if self._backoff and not self._waiting_messages: + _logger.info("Client told to backoff - sleeping for %r seconds", self._backoff) + await self._connection.sleep_async(self._backoff) + self._backoff = 0 + await self._connection.work_async() + else: + try: + await self._link.update_pending_deliveries() + await self._connection.listen(wait=self._socket_timeout, **kwargs) + except ValueError: + self._shutdown = True + return False return True async with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: async with sb_client.get_queue_sender(servicebus_queue.name) as sender: # this one doesn't need to reset the method, as it's hacking the method on the instance sender._handler._client_run_async = types.MethodType(_hack_amqp_sender_run_async, sender._handler) @@ -1880,37 +2056,68 @@ async def _hack_amqp_sender_run_async(cls): @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @CachedServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) - async def test_async_queue_mgmt_operation_timeout(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): - async def hack_mgmt_execute_async(self, operation, op_type, message, timeout=0): - start_time = self._counter.get_current_ms() - operation_id = str(uuid.uuid4()) - self._responses[operation_id] = None - - await asyncio.sleep(6) # sleep until timeout - while not self._responses[operation_id] and not self.mgmt_error: - if timeout > 0: - now = self._counter.get_current_ms() - if 
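The test binds `_hack_amqp_sender_run_async` onto a single handler instance with `types.MethodType`, which is why, as the comment notes, no reset is needed afterwards: other instances keep the real method. A minimal sketch of that instance-level patch, with a toy `Handler` class:

```python
import types

class Handler:
    def run(self):
        return "real work"

def _hack_run(self):
    # Instance-level override: `self` is bound to one handler only.
    return "hacked: " + self.__class__.__name__

handler = Handler()
# Bind the hack to this instance; the class attribute is untouched.
handler.run = types.MethodType(_hack_run, handler)
patched = handler.run()
untouched = Handler().run()
```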
(now - start_time) >= timeout: - raise compat.TimeoutException("Failed to receive mgmt response in {}ms".format(timeout)) - await self.connection.work_async() - if self.mgmt_error: - raise self.mgmt_error - response = self._responses.pop(operation_id) - return response - - original_execute_method = uamqp.async_ops.mgmt_operation_async.MgmtOperationAsync.execute_async - # hack the mgmt method on the class, not on an instance, so it needs reset + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_async_queue_mgmt_operation_timeout(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): + if uamqp_transport: + async def hack_mgmt_execute_async(self, operation, op_type, message, timeout=0): + start_time = self._counter.get_current_ms() + operation_id = str(uuid.uuid4()) + self._responses[operation_id] = None + + await asyncio.sleep(6) # sleep until timeout + while not self._responses[operation_id] and not self.mgmt_error: + if timeout > 0: + now = self._counter.get_current_ms() + if (now - start_time) >= timeout: + raise uamqp.compat.TimeoutException("Failed to receive mgmt response in {}ms".format(timeout)) + await self.connection.work_async() + if self.mgmt_error: + raise self.mgmt_error + response = self._responses.pop(operation_id) + return response + + original_execute_method = uamqp.async_ops.mgmt_operation_async.MgmtOperationAsync.execute_async + # hack the mgmt method on the class, not on an instance, so it needs reset + else: + async def hack_mgmt_execute_async(self, message, operation=None, operation_type=None, timeout=0): + start_time = time.time() + operation_id = str(uuid.uuid4()) + self._responses[operation_id] = None + self._mgmt_error = None + + await asyncio.sleep(6) # sleep until timeout + while not self._responses[operation_id] and not self._mgmt_error: + if timeout and timeout > 0: + now = time.time() + if (now - 
start_time) >= timeout: + raise TimeoutError("Failed to receive mgmt response in {}ms".format(timeout)) + await self.connection.listen() + if self._mgmt_error: + self._responses.pop(operation_id) + raise self._mgmt_error + response = self._responses.pop(operation_id) + return response + + original_execute_method = _management_operation_async.ManagementOperation.execute + # hack the mgmt method on the class, not on an instance, so it needs reset try: - uamqp.async_ops.mgmt_operation_async.MgmtOperationAsync.execute_async = hack_mgmt_execute_async + if uamqp_transport: + uamqp.async_ops.mgmt_operation_async.MgmtOperationAsync.execute_async = hack_mgmt_execute_async + else: + _management_operation_async.ManagementOperation.execute = hack_mgmt_execute_async async with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: async with sb_client.get_queue_sender(servicebus_queue.name) as sender: with pytest.raises(OperationTimeoutError): scheduled_time_utc = utc_now() + timedelta(seconds=30) await sender.schedule_messages(ServiceBusMessage("ServiceBusMessage to be scheduled"), scheduled_time_utc, timeout=5) finally: # must reset the mgmt execute method, otherwise other test cases would use the hacked execute method, leading to timeout error - uamqp.async_ops.mgmt_operation_async.MgmtOperationAsync.execute_async = original_execute_method + if uamqp_transport: + uamqp.async_ops.mgmt_operation_async.MgmtOperationAsync.execute_async = original_execute_method + else: + _management_operation_async.ManagementOperation.execute = original_execute_method @pytest.mark.asyncio @pytest.mark.liveTest @@ -1918,55 +2125,85 @@ async def hack_mgmt_execute_async(self, operation, op_type, message, timeout=0): @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') 
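By contrast, `hack_mgmt_execute_async` above is swapped in at class level, so the `finally` block must restore the original or every later test would inherit the mock (the comment in the diff says exactly this). The patch-and-restore pattern, sketched with a toy `Client` class:

```python
class Client:
    def execute(self):
        return "normal response"

def hacked_execute(self):
    # Class-level mock: every instance created while this is installed fails.
    raise TimeoutError("mock timeout")

original = Client.execute
try:
    Client.execute = hacked_execute
    try:
        Client().execute()
        raised = False
    except TimeoutError:
        raised = True
finally:
    # Must restore, otherwise later uses of Client keep hitting the mock.
    Client.execute = original

restored_result = Client().execute()
```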
@CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @CachedServiceBusQueuePreparer(name_prefix='servicebustest', lock_duration='PT10S') - async def test_async_queue_operation_negative(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): - def _hack_amqp_message_complete(cls): - raise RuntimeError() + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_async_queue_operation_negative(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): + if uamqp_transport: + def _hack_amqp_message_complete(cls): + raise RuntimeError() + + def _hack_amqp_mgmt_request(cls, message, operation, op_type=None, node=None, callback=None, **kwargs): + raise uamqp.errors.AMQPConnectionError() + + def _hack_sb_receiver_settle_message(self, message, settle_operation, dead_letter_reason=None, dead_letter_error_description=None): + raise uamqp.errors.AMQPError() + else: + async def _hack_amqp_message_complete(cls, _, settlement): + if settlement == 'completed': + raise RuntimeError() - async def _hack_amqp_mgmt_request(cls, message, operation, op_type=None, node=None, callback=None, **kwargs): - raise uamqp.errors.AMQPConnectionError() + async def _hack_amqp_mgmt_request(cls, message, operation, op_type=None, node=None, callback=None, **kwargs): + raise error.AMQPConnectionError(error.ErrorCondition.ConnectionCloseForced) - async def _hack_sb_receiver_settle_message(self, settle_operation, dead_letter_reason=None, dead_letter_error_description=None): - raise uamqp.errors.AMQPError() + async def _hack_sb_receiver_settle_message(self, message, settle_operation, dead_letter_reason=None, dead_letter_error_description=None): + raise error.AMQPException(error.ErrorCondition.ClientError) async with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + 
servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: sender = sb_client.get_queue_sender(servicebus_queue.name) receiver = sb_client.get_queue_receiver(servicebus_queue.name, max_wait_time=10) - async with sender, receiver: - # negative settlement via receiver link - await sender.send_messages(ServiceBusMessage("body"), timeout=5) - message = (await receiver.receive_messages(max_wait_time=10))[0] - message.message.accept = types.MethodType(_hack_amqp_message_complete, message.message) - await receiver.complete_message(message) # settle via mgmt link + if not uamqp_transport: + original_settlement = ReceiveClientAsync.settle_messages_async + try: + async with sender, receiver: + # negative settlement via receiver link + await sender.send_messages(ServiceBusMessage("body"), timeout=5) + message = (await receiver.receive_messages(max_wait_time=10))[0] + if uamqp_transport: + message._message.accept = types.MethodType(_hack_amqp_message_complete, message._message) + else: + ReceiveClientAsync.settle_messages_async = types.MethodType(_hack_amqp_message_complete, receiver._handler) + await receiver.complete_message(message) # settle via mgmt link - origin_amqp_client_mgmt_request_method = uamqp.AMQPClientAsync.mgmt_request_async - try: - uamqp.AMQPClientAsync.mgmt_request_async = _hack_amqp_mgmt_request - with pytest.raises(ServiceBusConnectionError): - receiver._handler.mgmt_request_async = types.MethodType(_hack_amqp_mgmt_request, receiver._handler) - await receiver.peek_messages() - finally: - uamqp.AMQPClientAsync.mgmt_request_async = origin_amqp_client_mgmt_request_method + if uamqp_transport: + amqp_client = uamqp.AMQPClientAsync + else: + amqp_client = AMQPClientAsync + + origin_amqp_client_mgmt_request_method = amqp_client.mgmt_request_async + try: + amqp_client.mgmt_request_async = _hack_amqp_mgmt_request + with pytest.raises(ServiceBusConnectionError): + receiver._handler.mgmt_request_async = 
types.MethodType(_hack_amqp_mgmt_request, receiver._handler) + await receiver.peek_messages() + finally: + amqp_client.mgmt_request_async = origin_amqp_client_mgmt_request_method - await sender.send_messages(ServiceBusMessage("body"), timeout=5) + await sender.send_messages(ServiceBusMessage("body"), timeout=5) - message = (await receiver.receive_messages(max_wait_time=10))[0] - origin_sb_receiver_settle_message_method = receiver._settle_message - receiver._settle_message = types.MethodType(_hack_sb_receiver_settle_message, receiver) - with pytest.raises(ServiceBusError): - await receiver.complete_message(message) + message = (await receiver.receive_messages(max_wait_time=10))[0] + origin_sb_receiver_settle_message_method = receiver._settle_message + receiver._settle_message = types.MethodType(_hack_sb_receiver_settle_message, receiver) + with pytest.raises(ServiceBusError): + await receiver.complete_message(message) - receiver._settle_message = origin_sb_receiver_settle_message_method - message = (await receiver.receive_messages(max_wait_time=10))[0] - await receiver.complete_message(message) + receiver._settle_message = origin_sb_receiver_settle_message_method + message = (await receiver.receive_messages(max_wait_time=10))[0] + await receiver.complete_message(message) + finally: + if not uamqp_transport: + ReceiveClientAsync.settle_messages_async = original_settlement + @pytest.mark.asyncio @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) - async def test_async_send_message_no_body(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_async_send_message_no_body(self, uamqp_transport, 
*, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): async with ServiceBusClient.from_connection_string( servicebus_namespace_connection_string) as sb_client: @@ -1978,18 +2215,20 @@ async def test_async_send_message_no_body(self, servicebus_namespace_connection_ message = await receiver.__anext__() assert message.body is None await receiver.complete_message(message) - + @pytest.mark.asyncio @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @CachedServiceBusQueuePreparer(name_prefix='servicebustest') - async def test_async_queue_by_servicebus_client_enum_case_sensitivity(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_async_queue_by_servicebus_client_enum_case_sensitivity(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): # Note: This test is currently intended to enforce case-sensitivity. If we eventually upgrade to the Fancy Enums being used with new autorest, # we may want to tweak this. 
async with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: async with sb_client.get_queue_receiver(servicebus_queue.name, receive_mode=ServiceBusReceiveMode.RECEIVE_AND_DELETE.value, max_wait_time=5) as receiver: @@ -2008,16 +2247,19 @@ async def test_async_queue_by_servicebus_client_enum_case_sensitivity(self, serv sub_queue=str.upper(ServiceBusSubQueue.DEAD_LETTER.value), max_wait_time=5) as receiver: raise Exception("Should not get here, should be case sensitive.") - - @pytest.mark.asyncio + + + @pytest.mark.asyncio @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest') - async def test_queue_async_send_dict_messages(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_queue_async_send_dict_messages(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): async with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: async with sb_client.get_queue_sender(servicebus_queue.name) as sender: @@ -2043,13 +2285,29 @@ async def test_queue_async_send_dict_messages(self, servicebus_namespace_connect received_messages.append(message) assert len(received_messages) == 6 + batch_message = await sender.create_message_batch(max_size_in_bytes=73) + for _ in range(2): + try: + batch_message.add_message(message_dict) + except ValueError: + break + 
await sender.send_messages(batch_message) + received_messages = [] + async with sb_client.get_queue_receiver(servicebus_queue.name, max_wait_time=5) as receiver: + async for message in receiver: + received_messages.append(message) + assert len(received_messages) == 1 + + @pytest.mark.asyncio @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest') - async def test_queue_async_send_mapping_messages(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_queue_async_send_mapping_messages(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): class MappingMessage(DictMixin): def __init__(self, content): self.body = content @@ -2060,7 +2318,7 @@ def __init__(self): self.message_id = 'foo' async with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: async with sb_client.get_queue_sender(servicebus_queue.name) as sender: @@ -2097,9 +2355,11 @@ def __init__(self): @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest') - async def test_queue_async_send_dict_messages_error_badly_formatted_dicts(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_queue_async_send_dict_messages_error_badly_formatted_dicts(self, uamqp_transport, *, 
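`create_message_batch(max_size_in_bytes=73)` above deliberately makes the batch too small for a second message, so `add_message` raises `ValueError` and the loop breaks with exactly one message queued. A toy version of that size-capped batch (the `MessageBatch` class here is illustrative, not the SDK type):

```python
class MessageBatch:
    # Toy batch with a byte budget; add_message raises ValueError when full,
    # matching the ServiceBusMessageBatch contract the test relies on.
    def __init__(self, max_size_in_bytes):
        self.max_size_in_bytes = max_size_in_bytes
        self.messages = []
        self._size = 0

    def add_message(self, body: bytes):
        if self._size + len(body) > self.max_size_in_bytes:
            raise ValueError("batch is full")
        self.messages.append(body)
        self._size += len(body)

batch = MessageBatch(max_size_in_bytes=10)
for _ in range(5):
    try:
        batch.add_message(b"123456")  # 6 bytes each; only one fits in 10
    except ValueError:
        break  # same break-on-full pattern as the test
count = len(batch.messages)
```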
servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): async with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: async with sb_client.get_queue_sender(servicebus_queue.name) as sender: @@ -2126,10 +2386,12 @@ async def test_queue_async_send_dict_messages_error_badly_formatted_dicts(self, @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) - async def test_queue_async_send_dict_messages_scheduled(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_queue_async_send_dict_messages_scheduled(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): async with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: content = "Test scheduled message" message_id = uuid.uuid4() message_id2 = uuid.uuid4() @@ -2188,10 +2450,12 @@ async def test_queue_async_send_dict_messages_scheduled(self, servicebus_namespa @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) - async def test_queue_async_send_dict_messages_scheduled_error_badly_formatted_dicts(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + 
@pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_queue_async_send_dict_messages_scheduled_error_badly_formatted_dicts(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): async with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: content = "Test scheduled message" message_id = uuid.uuid4() message_id2 = uuid.uuid4() @@ -2206,33 +2470,50 @@ async def test_queue_async_send_dict_messages_scheduled_error_badly_formatted_di with pytest.raises(TypeError): await sender.schedule_messages(list_message_dicts, scheduled_enqueue_time) + @pytest.mark.asyncio @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) - async def test_queue_async_receive_iterator_resume_after_link_detach(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): - - async def hack_iter_next_mock_error(self): - await self._open() - # when trying to receive the second message (execution_times is 1), raising LinkDetach error to mock 10 mins idle timeout - if self.execution_times == 1: - from uamqp.errors import LinkDetach - from uamqp.constants import ErrorCodes - self.execution_times += 1 - self.error_raised = True - raise LinkDetach(ErrorCodes.LinkDetachForced) - else: - self.execution_times += 1 - if not self._message_iter: - self._message_iter = self._handler.receive_messages_iter_async() - uamqp_message = await self._message_iter.__anext__() - message = self._build_message(uamqp_message) - return message + @pytest.mark.parametrize("uamqp_transport", 
uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_queue_async_receive_iterator_resume_after_link_detach(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): + + async def hack_iter_next_mock_error(self, wait_time=None): + try: + self._receive_context.set() + await self._open() + # when trying to receive the second message (execution_times is 1), raising LinkDetach error to mock 10 mins idle timeout + if self.execution_times == 1: + if uamqp_transport: + from uamqp.errors import LinkDetach + from uamqp.constants import ErrorCodes + error = LinkDetach + error_condition = ErrorCodes + else: + from azure.servicebus._pyamqp.error import ErrorCondition, AMQPLinkError + error = AMQPLinkError + error_condition = ErrorCondition + self.execution_times += 1 + self.error_raised = True + raise error(error_condition.LinkDetachForced) + else: + self.execution_times += 1 + if not self._message_iter: + if uamqp_transport: + self._message_iter = self._handler.receive_messages_iter_async() + else: + self._message_iter = await self._handler.receive_messages_iter_async(timeout=wait_time) + amqp_message = await self._message_iter.__anext__() + message = self._build_received_message(amqp_message) + return message + finally: + self._receive_context.clear() async with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: async with sb_client.get_queue_sender(servicebus_queue.name) as sender: await sender.send_messages( [ServiceBusMessage("test1"), ServiceBusMessage("test2"), ServiceBusMessage("test3")] @@ -2255,10 +2536,12 @@ async def hack_iter_next_mock_error(self): @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') 
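`hack_iter_next_mock_error` above raises a link-detach error exactly once (when `execution_times == 1`, i.e. on the second receive) to verify the iterator resumes and still yields all three messages. The raise-once-then-resume shape can be sketched synchronously (all names illustrative):

```python
class TransientLinkError(Exception):
    pass

class Receiver:
    def __init__(self, messages):
        self._messages = iter(messages)
        self.execution_times = 0
        self.error_raised = False

    def next_message(self):
        # Fail exactly once, on the second call, to mock an idle-timeout detach.
        if self.execution_times == 1:
            self.execution_times += 1
            self.error_raised = True
            raise TransientLinkError()
        self.execution_times += 1
        return next(self._messages)

receiver = Receiver(["test1", "test2", "test3"])
received = []
while len(received) < 3:
    try:
        received.append(receiver.next_message())
    except TransientLinkError:
        continue  # the caller retries after the link is re-opened
```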
@ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) - async def test_queue_async_send_amqp_annotated_message(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_queue_async_send_amqp_annotated_message(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): async with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: sequence_body = [b'message', 123.456, True] footer = {'footer_key': 'footer_value'} prop = {"subject": "sequence"} @@ -2296,8 +2579,12 @@ async def test_queue_async_send_amqp_annotated_message(self, servicebus_namespac dict_message = {"body": content} sb_message = ServiceBusMessage(body=content) message_with_ttl = AmqpAnnotatedMessage(data_body=data_body, header=AmqpMessageHeader(time_to_live=60000)) - uamqp_with_ttl = message_with_ttl._to_outgoing_amqp_message() - assert uamqp_with_ttl.properties.absolute_expiry_time == uamqp_with_ttl.properties.creation_time + uamqp_with_ttl.header.time_to_live + if uamqp_transport: + amqp_transport = UamqpTransportAsync + else: + amqp_transport = PyamqpTransportAsync + amqp_with_ttl = amqp_transport.to_outgoing_amqp_message(message_with_ttl) + assert amqp_with_ttl.properties.absolute_expiry_time == amqp_with_ttl.properties.creation_time + amqp_with_ttl.header.ttl recv_data_msg = recv_sequence_msg = recv_value_msg = normal_msg = 0 async with sb_client.get_queue_receiver(servicebus_queue.name, max_wait_time=10) as receiver: @@ -2347,16 +2634,17 @@ async def test_queue_async_send_amqp_annotated_message(self, servicebus_namespac assert recv_value_msg == 3 assert normal_msg == 4 - 
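The TTL assertion above checks `absolute_expiry_time == creation_time + ttl`, with all three values in milliseconds (the test sets `time_to_live=60000`, i.e. 60 seconds). In plain arithmetic, with an illustrative creation timestamp:

```python
# The SDK assertion: absolute_expiry_time == creation_time + ttl, all in ms.
creation_time = 1_600_000_000_000   # ms since the epoch (illustrative value)
time_to_live = 60_000               # from AmqpMessageHeader(time_to_live=60000)
absolute_expiry_time = creation_time + time_to_live
```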
@pytest.mark.asyncio @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) - async def test_state_scheduled_async(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_state_scheduled_async(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): async with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string) as sb_client: + servicebus_namespace_connection_string, uamqp_transport=uamqp_transport) as sb_client: sender = sb_client.get_queue_sender(servicebus_queue.name) async with sender: @@ -2371,16 +2659,17 @@ async def test_state_scheduled_async(self, servicebus_namespace_connection_strin for msg in messages: assert msg.state == ServiceBusMessageState.SCHEDULED - @pytest.mark.asyncio @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) - async def test_state_deferred_async(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_state_deferred_async(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): async with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string) as sb_client: + servicebus_namespace_connection_string, uamqp_transport=uamqp_transport) as 
sb_client: sender = sb_client.get_queue_sender(servicebus_queue.name) async with sender: diff --git a/sdk/servicebus/azure-servicebus/tests/async_tests/test_sb_client_async.py b/sdk/servicebus/azure-servicebus/tests/async_tests/test_sb_client_async.py index d42e1dfd218fd..97bbe4d7832b8 100644 --- a/sdk/servicebus/azure-servicebus/tests/async_tests/test_sb_client_async.py +++ b/sdk/servicebus/azure-servicebus/tests/async_tests/test_sb_client_async.py @@ -4,12 +4,17 @@ # license information. #-------------------------------------------------------------------------- - import logging +import sys import time +import asyncio import pytest +import hmac +import hashlib +import base64 +from urllib.parse import quote as url_parse_quote -from azure.core.credentials import AzureSasCredential, AzureNamedKeyCredential +from azure.core.credentials import AzureSasCredential, AzureNamedKeyCredential, AccessToken from azure.mgmt.servicebus.models import AccessRights from azure.servicebus.aio import ServiceBusClient, ServiceBusSender, ServiceBusReceiver from azure.servicebus import ServiceBusMessage @@ -17,9 +22,10 @@ from azure.servicebus.exceptions import ( ServiceBusError, ServiceBusAuthenticationError, - ServiceBusAuthorizationError + ServiceBusAuthorizationError, + ServiceBusConnectionError ) -from devtools_testutils import AzureMgmtTestCase +from devtools_testutils import AzureMgmtRecordedTestCase from servicebus_preparer import ( CachedServiceBusNamespacePreparer, ServiceBusTopicPreparer, @@ -32,11 +38,13 @@ CachedServiceBusResourceGroupPreparer, SERVICEBUS_ENDPOINT_SUFFIX ) -from utilities import get_logger +from utilities import get_logger, uamqp_transport as get_uamqp_transport, ArgPasserAsync + +uamqp_transport_params, uamqp_transport_ids = get_uamqp_transport() _logger = get_logger(logging.DEBUG) -class ServiceBusClientAsyncTests(AzureMgmtTestCase): +class TestServiceBusClientAsync(AzureMgmtRecordedTestCase): @pytest.mark.asyncio @pytest.mark.liveTest @@ -44,11 +52,13 @@ 
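The new imports (`hmac`, `hashlib`, `base64`, `quote`) are the usual ingredients of a Service Bus SAS token: an HMAC-SHA256 signature over the URL-encoded resource URI plus an expiry timestamp. A hedged sketch — `generate_sas_token` and its exact token layout are illustrative, not code from this patch:

```python
import base64
import hashlib
import hmac
import time
from urllib.parse import quote as url_parse_quote

def generate_sas_token(uri, policy, key, expiry_seconds=3600):
    # String-to-sign is the URL-encoded URI, a newline, and the expiry epoch.
    expiry = int(time.time() + expiry_seconds)
    encoded_uri = url_parse_quote(uri, safe="")
    to_sign = (encoded_uri + "\n" + str(expiry)).encode("utf-8")
    signature = base64.b64encode(
        hmac.new(key.encode("utf-8"), to_sign, hashlib.sha256).digest()
    )
    return "SharedAccessSignature sr={}&sig={}&se={}&skn={}".format(
        encoded_uri, url_parse_quote(signature, safe=""), expiry, policy
    )

token = generate_sas_token("sb://mock.servicebus.windows.net/queue", "policy", "secretkey")
```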
class ServiceBusClientAsyncTests(AzureMgmtTestCase): @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @CachedServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) - async def test_sb_client_bad_credentials_async(self, servicebus_namespace, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_sb_client_bad_credentials_async(self, uamqp_transport, *, servicebus_namespace, servicebus_queue, **kwargs): client = ServiceBusClient( fully_qualified_namespace=f"{servicebus_namespace.name}{SERVICEBUS_ENDPOINT_SUFFIX}", credential=ServiceBusSharedKeyCredential('invalid', 'invalid'), - logging_enable=False) + logging_enable=False, uamqp_transport=uamqp_transport) async with client: with pytest.raises(ServiceBusAuthenticationError): async with client.get_queue_sender(servicebus_queue.name) as sender: @@ -57,12 +67,14 @@ async def test_sb_client_bad_credentials_async(self, servicebus_namespace, servi @pytest.mark.asyncio @pytest.mark.liveTest @pytest.mark.live_test_only - async def test_sb_client_bad_namespace_async(self, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_sb_client_bad_namespace_async(self, uamqp_transport, **kwargs): client = ServiceBusClient( fully_qualified_namespace=f'invalid{SERVICEBUS_ENDPOINT_SUFFIX}', credential=ServiceBusSharedKeyCredential('invalid', 'invalid'), - logging_enable=False) + logging_enable=False, uamqp_transport=uamqp_transport) async with client: with pytest.raises(ServiceBusError): async with client.get_queue_sender('invalidqueue') as sender: @@ -73,10 +85,19 @@ async def test_sb_client_bad_namespace_async(self, **kwargs): @pytest.mark.live_test_only 
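Nearly every test in this diff gains the same pair of decorators so it runs once per AMQP transport. A minimal sketch of that wiring, with a stand-in `get_uamqp_transport` (the real helper lives in the tests' `utilities` module, and its return values depend on whether `uamqp` is installed):

```python
import pytest

def get_uamqp_transport():
    # Stand-in for utilities.uamqp_transport: one param per transport.
    uamqp_transport_params = [True, False]       # True -> uamqp, False -> pyamqp
    uamqp_transport_ids = ["uamqp", "pyamqp"]    # readable test-case ids
    return uamqp_transport_params, uamqp_transport_ids

uamqp_transport_params, uamqp_transport_ids = get_uamqp_transport()

# Each test then forwards the flag into the client constructor
# (e.g. ServiceBusClient(..., uamqp_transport=uamqp_transport)).
@pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids)
def test_transport_flag(uamqp_transport):
    assert isinstance(uamqp_transport, bool)
```

With this shape, pytest reports each case as `test_transport_flag[uamqp]` / `test_transport_flag[pyamqp]`, which is why the diff threads `ids=uamqp_transport_ids` through every parametrize call.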
@CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') - async def test_sb_client_bad_entity_async(self): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_sb_client_bad_entity_async(self, uamqp_transport, *, servicebus_namespace_connection_string=None, **kwargs): + client = ServiceBusClient.from_connection_string(servicebus_namespace_connection_string, uamqp_transport=uamqp_transport) + + async with client: + with pytest.raises(ServiceBusAuthenticationError): + async with client.get_queue_sender("invalid") as sender: + await sender.send_messages(ServiceBusMessage("test")) + fake_str = f"Endpoint=sb://mock{SERVICEBUS_ENDPOINT_SUFFIX}/;" \ f"SharedAccessKeyName=mock;SharedAccessKey=mock;EntityPath=mockentity" - fake_client = ServiceBusClient.from_connection_string(fake_str) + fake_client = ServiceBusClient.from_connection_string(fake_str, uamqp_transport=uamqp_transport) with pytest.raises(ValueError): fake_client.get_queue_sender('queue') @@ -110,8 +131,10 @@ async def test_sb_client_bad_entity_async(self): @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @CachedServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) @ServiceBusNamespaceAuthorizationRulePreparer(name_prefix='servicebustest', access_rights=[AccessRights.listen]) - async def test_sb_client_readonly_credentials(self, servicebus_authorization_rule_connection_string, servicebus_queue, **kwargs): - client = ServiceBusClient.from_connection_string(servicebus_authorization_rule_connection_string) + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_sb_client_readonly_credentials(self, uamqp_transport, *, servicebus_authorization_rule_connection_string=None, servicebus_queue=None, **kwargs): + client = 
ServiceBusClient.from_connection_string(servicebus_authorization_rule_connection_string, uamqp_transport=uamqp_transport) async with client: async with client.get_queue_receiver(servicebus_queue.name) as receiver: @@ -128,8 +151,10 @@ async def test_sb_client_readonly_credentials(self, servicebus_authorization_rul @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @CachedServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) @ServiceBusNamespaceAuthorizationRulePreparer(name_prefix='servicebustest', access_rights=[AccessRights.send]) - async def test_sb_client_writeonly_credentials_async(self, servicebus_authorization_rule_connection_string, servicebus_queue, **kwargs): - client = ServiceBusClient.from_connection_string(servicebus_authorization_rule_connection_string) + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_sb_client_writeonly_credentials_async(self, uamqp_transport, *, servicebus_authorization_rule_connection_string=None, servicebus_queue=None, **kwargs): + client = ServiceBusClient.from_connection_string(servicebus_authorization_rule_connection_string, uamqp_transport=uamqp_transport) async with client: with pytest.raises(ServiceBusError): @@ -150,8 +175,10 @@ async def test_sb_client_writeonly_credentials_async(self, servicebus_authorizat @CachedServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) @CachedServiceBusTopicPreparer(name_prefix='servicebustest') @CachedServiceBusSubscriptionPreparer(name_prefix='servicebustest') - async def test_async_sb_client_close_spawned_handlers(self, servicebus_namespace_connection_string, servicebus_queue, servicebus_topic, servicebus_subscription, **kwargs): - client = ServiceBusClient.from_connection_string(servicebus_namespace_connection_string) + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, 
ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_async_sb_client_close_spawned_handlers(self, uamqp_transport, *, servicebus_namespace_connection_string, servicebus_queue, servicebus_topic, servicebus_subscription, **kwargs): + client = ServiceBusClient.from_connection_string(servicebus_namespace_connection_string, uamqp_transport=uamqp_transport) await client.close() @@ -241,9 +268,11 @@ async def test_async_sb_client_close_spawned_handlers(self, servicebus_namespace @ServiceBusQueuePreparer(name_prefix='servicebustest_qone', parameter_name='wrong_queue', dead_lettering_on_message_expiration=True) @ServiceBusQueuePreparer(name_prefix='servicebustest_qtwo', dead_lettering_on_message_expiration=True) @ServiceBusQueueAuthorizationRulePreparer(name_prefix='servicebustest_qtwo') - async def test_sb_client_incorrect_queue_conn_str_async(self, servicebus_queue_authorization_rule_connection_string, servicebus_queue, wrong_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_sb_client_incorrect_queue_conn_str_async(self, uamqp_transport, *, servicebus_queue_authorization_rule_connection_string, servicebus_queue, wrong_queue, **kwargs): - client = ServiceBusClient.from_connection_string(servicebus_queue_authorization_rule_connection_string) + client = ServiceBusClient.from_connection_string(servicebus_queue_authorization_rule_connection_string, uamqp_transport=uamqp_transport) async with client: # Validate that the wrong sender/receiver queues with the right credentials fail. 
with pytest.raises(ValueError): @@ -294,12 +323,16 @@ async def test_sb_client_incorrect_queue_conn_str_async(self, servicebus_queue_a @CachedServiceBusResourceGroupPreparer() @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @CachedServiceBusQueuePreparer(name_prefix='servicebustest') + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() async def test_client_sas_credential_async(self, - servicebus_queue, - servicebus_namespace, - servicebus_namespace_key_name, - servicebus_namespace_primary_key, - servicebus_namespace_connection_string, + uamqp_transport, + *, + servicebus_queue=None, + servicebus_namespace=None, + servicebus_namespace_key_name=None, + servicebus_namespace_primary_key=None, + servicebus_namespace_connection_string=None, **kwargs): # This should "just work" to validate known-good. credential = ServiceBusSharedKeyCredential(servicebus_namespace_key_name, servicebus_namespace_primary_key) @@ -308,32 +341,75 @@ async def test_client_sas_credential_async(self, token = (await credential.get_token(auth_uri)).token # Finally let's do it with SAS token + conn str - token_conn_str = "Endpoint=sb://{}/;SharedAccessSignature={};".format(hostname, token.decode()) + token_conn_str = "Endpoint=sb://{}/;SharedAccessSignature={};".format(hostname, token) - client = ServiceBusClient.from_connection_string(token_conn_str) + client = ServiceBusClient.from_connection_string(token_conn_str, uamqp_transport=uamqp_transport) async with client: assert len(client._handlers) == 0 async with client.get_queue_sender(servicebus_queue.name) as sender: await sender.send_messages(ServiceBusMessage("foo")) + def generate_sas_token(uri, sas_name, sas_value, token_ttl): + """Performs the signing and encoding needed to generate a sas token from a sas key.""" + sas = sas_value.encode('utf-8') + expiry = str(int(time.time() + token_ttl)) + string_to_sign = (uri + '\n' + expiry).encode('utf-8') + 
signed_hmac_sha256 = hmac.HMAC(sas, string_to_sign, hashlib.sha256) + signature = url_parse_quote(base64.b64encode(signed_hmac_sha256.digest())) + return 'SharedAccessSignature sr={}&sig={}&se={}&skn={}'.format(uri, signature, expiry, sas_name) + + class CustomizedSASCredential(object): + def __init__(self, token, expiry): + """ + :param str token: The token string + :param float expiry: The epoch timestamp + """ + self.token = token + self.expiry = expiry + self.token_type = b"servicebus.windows.net:sastoken" + + async def get_token(self, *scopes, **kwargs): + """ + This method is automatically called when token is about to expire. + """ + return AccessToken(self.token, self.expiry) + + token_ttl = 5 # seconds + sas_token = generate_sas_token( + auth_uri, servicebus_namespace_key_name, servicebus_namespace_primary_key, token_ttl + ) + credential=CustomizedSASCredential(sas_token, time.time() + token_ttl) + + async with ServiceBusClient(hostname, credential, uamqp_transport=uamqp_transport) as client: + sender = client.get_queue_sender(queue_name=servicebus_queue.name) + await asyncio.sleep(10) + with pytest.raises(ServiceBusAuthenticationError): + async with sender: + message = ServiceBusMessage("Single Message") + await sender.send_messages(message) + @pytest.mark.asyncio @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer() @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @CachedServiceBusQueuePreparer(name_prefix='servicebustest') + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() async def test_client_credential_async(self, - servicebus_queue, - servicebus_namespace, - servicebus_namespace_key_name, - servicebus_namespace_primary_key, - servicebus_namespace_connection_string, + uamqp_transport, + *, + servicebus_queue=None, + servicebus_namespace=None, + servicebus_namespace_key_name=None, + servicebus_namespace_primary_key=None, + 
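The `generate_sas_token` helper introduced in this hunk can be exercised on its own; the following standalone sketch reproduces its logic with only the standard library (the namespace URI and key values are placeholders for illustration, not values from the patch):

```python
import base64
import hashlib
import hmac
import time
from urllib.parse import quote as url_parse_quote

def generate_sas_token(uri, sas_name, sas_value, token_ttl):
    """Sign an (audience, expiry) pair with the shared access key."""
    sas = sas_value.encode('utf-8')
    expiry = str(int(time.time() + token_ttl))
    string_to_sign = (uri + '\n' + expiry).encode('utf-8')
    signed_hmac_sha256 = hmac.HMAC(sas, string_to_sign, hashlib.sha256)
    signature = url_parse_quote(base64.b64encode(signed_hmac_sha256.digest()))
    return 'SharedAccessSignature sr={}&sig={}&se={}&skn={}'.format(
        uri, signature, expiry, sas_name)

# Hypothetical namespace and key, for illustration only.
token = generate_sas_token(
    "sb://my-namespace.servicebus.windows.net/my-queue",
    "RootManageSharedAccessKey",
    "not-a-real-key=",
    token_ttl=300,
)
```

The expiry is embedded in the signed string, which is why the test can mint a token with a 5-second TTL and then sleep past it to force an authentication failure.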
servicebus_namespace_connection_string=None, **kwargs): # This should "just work" to validate known-good. credential = ServiceBusSharedKeyCredential(servicebus_namespace_key_name, servicebus_namespace_primary_key) hostname = f"{servicebus_namespace.name}{SERVICEBUS_ENDPOINT_SUFFIX}" - client = ServiceBusClient(hostname, credential) + client = ServiceBusClient(hostname, credential, uamqp_transport=uamqp_transport) async with client: assert len(client._handlers) == 0 async with client.get_queue_sender(servicebus_queue.name) as sender: @@ -341,7 +417,7 @@ async def test_client_credential_async(self, hostname = f"sb://{servicebus_namespace.name}{SERVICEBUS_ENDPOINT_SUFFIX}" - client = ServiceBusClient(hostname, credential) + client = ServiceBusClient(hostname, credential, uamqp_transport=uamqp_transport) async with client: assert len(client._handlers) == 0 async with client.get_queue_sender(servicebus_queue.name) as sender: @@ -349,7 +425,7 @@ async def test_client_credential_async(self, hostname = f"https://{servicebus_namespace.name}{SERVICEBUS_ENDPOINT_SUFFIX}" - client = ServiceBusClient(hostname, credential) + client = ServiceBusClient(hostname, credential, uamqp_transport=uamqp_transport) async with client: assert len(client._handlers) == 0 async with client.get_queue_sender(servicebus_queue.name) as sender: @@ -361,22 +437,26 @@ async def test_client_credential_async(self, @CachedServiceBusResourceGroupPreparer() @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @CachedServiceBusQueuePreparer(name_prefix='servicebustest') + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() async def test_client_azure_sas_credential_async(self, - servicebus_queue, - servicebus_namespace, - servicebus_namespace_key_name, - servicebus_namespace_primary_key, - servicebus_namespace_connection_string, + uamqp_transport, + *, + servicebus_queue=None, + servicebus_namespace=None, + 
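The expiring-token check added to `test_client_sas_credential_async` relies on a credential whose async `get_token` always hands back the same, eventually-expired token. A standalone sketch of that shape, with a `namedtuple` standing in for `azure.core.credentials.AccessToken`:

```python
import asyncio
import time
from collections import namedtuple

# Stand-in for azure.core.credentials.AccessToken.
AccessToken = namedtuple("AccessToken", ["token", "expires_on"])

class CustomizedSASCredential:
    """Hands out one fixed SAS token and never refreshes it."""
    def __init__(self, token, expiry):
        self.token = token
        self.expiry = expiry
        self.token_type = b"servicebus.windows.net:sastoken"

    async def get_token(self, *scopes, **kwargs):
        # The client calls this whenever it needs (re)authorization;
        # once past `expiry`, the service rejects the stale token.
        return AccessToken(self.token, self.expiry)

async def main():
    cred = CustomizedSASCredential("SharedAccessSignature sr=...", time.time() + 5)
    issued = await cred.get_token("sb://example.servicebus.windows.net/")
    return issued.token

assert asyncio.run(main()) == "SharedAccessSignature sr=..."
```

Because the credential never rotates the token, sleeping for 10 seconds against a 5-second TTL guarantees the `ServiceBusAuthenticationError` the test asserts on.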
servicebus_namespace_key_name=None, + servicebus_namespace_primary_key=None, + servicebus_namespace_connection_string=None, **kwargs): # This should "just work" to validate known-good. credential = ServiceBusSharedKeyCredential(servicebus_namespace_key_name, servicebus_namespace_primary_key) hostname = f"{servicebus_namespace.name}{SERVICEBUS_ENDPOINT_SUFFIX}" auth_uri = "sb://{}/{}".format(hostname, servicebus_queue.name) - token = (await credential.get_token(auth_uri)).token.decode() + token = (await credential.get_token(auth_uri)).token credential = AzureSasCredential(token) - client = ServiceBusClient(hostname, credential) + client = ServiceBusClient(hostname, credential, uamqp_transport=uamqp_transport) async with client: assert len(client._handlers) == 0 async with client.get_queue_sender(servicebus_queue.name) as sender: @@ -388,17 +468,21 @@ async def test_client_azure_sas_credential_async(self, @CachedServiceBusResourceGroupPreparer() @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @CachedServiceBusQueuePreparer(name_prefix='servicebustest') - async def test_client_named_key_credential_async(self, - servicebus_queue, - servicebus_namespace, - servicebus_namespace_key_name, - servicebus_namespace_primary_key, - servicebus_namespace_connection_string, + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_azure_named_key_credential_async(self, + uamqp_transport, + *, + servicebus_queue=None, + servicebus_namespace=None, + servicebus_namespace_key_name=None, + servicebus_namespace_primary_key=None, + servicebus_namespace_connection_string=None, **kwargs): hostname = f"{servicebus_namespace.name}{SERVICEBUS_ENDPOINT_SUFFIX}" credential = AzureNamedKeyCredential(servicebus_namespace_key_name, servicebus_namespace_primary_key) - client = ServiceBusClient(hostname, credential) + client = ServiceBusClient(hostname, credential, uamqp_transport=uamqp_transport) async with 
client: async with client.get_queue_sender(servicebus_queue.name) as sender: await sender.send_messages(ServiceBusMessage("foo")) @@ -415,14 +499,16 @@ async def test_client_named_key_credential_async(self, async with client.get_queue_sender(servicebus_queue.name) as sender: await sender.send_messages(ServiceBusMessage("foo")) - async def test_backoff_fixed_retry(self): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + def test_backoff_fixed_retry(self, uamqp_transport): client = ServiceBusClient( 'fake.host.com', 'fake_eh', - retry_mode='fixed' + retry_mode='fixed', + uamqp_transport=uamqp_transport ) # queue sender - sender = await client.get_queue_sender('fake_name') + sender = client.get_queue_sender('fake_name') backoff = client._config.retry_backoff_factor start_time = time.time() sender._backoff(retried_times=1, last_exception=Exception('fake'), abs_timeout_time=None) @@ -433,7 +519,7 @@ async def test_backoff_fixed_retry(self): assert sleep_time_fixed < backoff * (2 ** 1) # topic sender - sender = await client.get_topic_sender('fake_name') + sender = client.get_topic_sender('fake_name') backoff = client._config.retry_backoff_factor start_time = time.time() sender._backoff(retried_times=1, last_exception=Exception('fake'), abs_timeout_time=None) @@ -441,7 +527,7 @@ async def test_backoff_fixed_retry(self): assert sleep_time_fixed < backoff * (2 ** 1) # queue receiver - receiver = await client.get_queue_receiver('fake_name') + receiver = client.get_queue_receiver('fake_name') backoff = client._config.retry_backoff_factor start_time = time.time() receiver._backoff(retried_times=1, last_exception=Exception('fake'), abs_timeout_time=None) @@ -449,86 +535,102 @@ async def test_backoff_fixed_retry(self): assert sleep_time_fixed < backoff * (2 ** 1) # subscription receiver - receiver = await client.get_subscription_receiver('fake_topic', 'fake_sub') + receiver = client.get_subscription_receiver('fake_topic', 
'fake_sub') backoff = client._config.retry_backoff_factor start_time = time.time() receiver._backoff(retried_times=1, last_exception=Exception('fake'), abs_timeout_time=None) sleep_time_fixed = time.time() - start_time assert sleep_time_fixed < backoff * (2 ** 1) - async def test_custom_client_id_queue_sender_async(self, **kwargs): + @pytest.mark.asyncio + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + async def test_custom_client_id_queue_sender_async(self, uamqp_transport, **kwargs): servicebus_connection_str = f'Endpoint=sb://resourcename{SERVICEBUS_ENDPOINT_SUFFIX}/;SharedAccessSignature=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=;' queue_name = "queue_name" custom_id = "my_custom_id" - servicebus_client = ServiceBusClient.from_connection_string(conn_str=servicebus_connection_str) + servicebus_client = ServiceBusClient.from_connection_string(conn_str=servicebus_connection_str, uamqp_transport=uamqp_transport) async with servicebus_client: queue_sender = servicebus_client.get_queue_sender(queue_name=queue_name, client_identifier=custom_id) assert queue_sender.client_identifier is not None assert queue_sender.client_identifier == custom_id - async def test_default_client_id_queue_sender(self, **kwargs): + @pytest.mark.asyncio + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + async def test_default_client_id_queue_sender(self, uamqp_transport, **kwargs): servicebus_connection_str = f'Endpoint=sb://resourcename{SERVICEBUS_ENDPOINT_SUFFIX}/;SharedAccessSignature=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=;' queue_name = "queue_name" - servicebus_client = ServiceBusClient.from_connection_string(conn_str=servicebus_connection_str) + servicebus_client = ServiceBusClient.from_connection_string(conn_str=servicebus_connection_str, uamqp_transport=uamqp_transport) async with servicebus_client: queue_sender = servicebus_client.get_queue_sender(queue_name=queue_name) assert 
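The rewritten `test_backoff_fixed_retry` asserts that a fixed-mode backoff sleeps for less than the exponential bound `backoff * 2**retried_times`. That relationship can be sketched without the SDK (the function and defaults below are illustrative, not the SDK's internals):

```python
def backoff_delay(retried_times, backoff_factor, retry_mode="exponential", backoff_max=120):
    """Delay (seconds) before retry attempt number `retried_times`."""
    if retry_mode == "fixed":
        delay = backoff_factor                         # constant per attempt
    else:
        delay = backoff_factor * (2 ** retried_times)  # doubles each attempt
    return min(delay, backoff_max)

fixed = backoff_delay(retried_times=1, backoff_factor=0.8, retry_mode="fixed")
# The bound the test measures with wall-clock time around _backoff():
assert fixed < 0.8 * (2 ** 1)
```

Measuring the actual sleep duration, as the test does, keeps the assertion valid for any positive `retry_backoff_factor` without reaching into the retry internals.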
queue_sender.client_identifier is not None assert "SBSender" in queue_sender.client_identifier - async def test_custom_client_id_queue_receiver(self, **kwargs): + @pytest.mark.asyncio + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + async def test_custom_client_id_queue_receiver(self, uamqp_transport, **kwargs): servicebus_connection_str = f'Endpoint=sb://resourcename{SERVICEBUS_ENDPOINT_SUFFIX}/;SharedAccessSignature=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=;' queue_name = "queue_name" custom_id = "my_custom_id" - servicebus_client = ServiceBusClient.from_connection_string(conn_str=servicebus_connection_str) + servicebus_client = ServiceBusClient.from_connection_string(conn_str=servicebus_connection_str, uamqp_transport=uamqp_transport) async with servicebus_client: queue_receiver = servicebus_client.get_queue_receiver(queue_name=queue_name, client_identifier=custom_id) assert queue_receiver.client_identifier is not None assert queue_receiver.client_identifier == custom_id - async def test_default_client_id_queue_receiver(self, **kwargs): + @pytest.mark.asyncio + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + async def test_default_client_id_queue_receiver(self, uamqp_transport, **kwargs): servicebus_connection_str = f'Endpoint=sb://resourcename{SERVICEBUS_ENDPOINT_SUFFIX}/;SharedAccessSignature=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=;' queue_name = "queue_name" - servicebus_client = ServiceBusClient.from_connection_string(conn_str=servicebus_connection_str) + servicebus_client = ServiceBusClient.from_connection_string(conn_str=servicebus_connection_str, uamqp_transport=uamqp_transport) async with servicebus_client: queue_receiver = servicebus_client.get_queue_receiver(queue_name=queue_name) assert queue_receiver.client_identifier is not None assert "SBReceiver" in queue_receiver.client_identifier - async def test_custom_client_id_topic_sender(self, **kwargs): + 
@pytest.mark.asyncio + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + async def test_custom_client_id_topic_sender(self, uamqp_transport, **kwargs): servicebus_connection_str = f'Endpoint=sb://resourcename{SERVICEBUS_ENDPOINT_SUFFIX}/;SharedAccessSignature=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=;' custom_id = "my_custom_id" topic_name = "topic_name" - servicebus_client = ServiceBusClient.from_connection_string(conn_str=servicebus_connection_str) + servicebus_client = ServiceBusClient.from_connection_string(conn_str=servicebus_connection_str, uamqp_transport=uamqp_transport) async with servicebus_client: topic_sender = servicebus_client.get_topic_sender(topic_name=topic_name, client_identifier=custom_id) assert topic_sender.client_identifier is not None assert topic_sender.client_identifier == custom_id - async def test_default_client_id_topic_sender(self, **kwargs): + @pytest.mark.asyncio + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + async def test_default_client_id_topic_sender(self, uamqp_transport, **kwargs): servicebus_connection_str = f'Endpoint=sb://resourcename{SERVICEBUS_ENDPOINT_SUFFIX}/;SharedAccessSignature=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=;' topic_name = "topic_name" - servicebus_client = ServiceBusClient.from_connection_string(conn_str=servicebus_connection_str) + servicebus_client = ServiceBusClient.from_connection_string(conn_str=servicebus_connection_str, uamqp_transport=uamqp_transport) async with servicebus_client: topic_sender = servicebus_client.get_topic_sender(topic_name=topic_name) assert topic_sender.client_identifier is not None assert "SBSender" in topic_sender.client_identifier - async def test_default_client_id_subscription_receiver(self, **kwargs): + @pytest.mark.asyncio + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + async def test_default_client_id_subscription_receiver(self, 
uamqp_transport, **kwargs): servicebus_connection_str = f'Endpoint=sb://resourcename{SERVICEBUS_ENDPOINT_SUFFIX}/;SharedAccessSignature=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=;' topic_name = "topic_name" sub_name = "sub_name" - servicebus_client = ServiceBusClient.from_connection_string(conn_str=servicebus_connection_str) + servicebus_client = ServiceBusClient.from_connection_string(conn_str=servicebus_connection_str, uamqp_transport=uamqp_transport) async with servicebus_client: subscription_receiver = servicebus_client.get_subscription_receiver(topic_name, sub_name) assert subscription_receiver.client_identifier is not None assert "SBReceiver" in subscription_receiver.client_identifier - async def test_custom_client_id_subscription_receiver(self, **kwargs): + @pytest.mark.asyncio + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + async def test_custom_client_id_subscription_receiver(self, uamqp_transport, **kwargs): servicebus_connection_str = f'Endpoint=sb://resourcename{SERVICEBUS_ENDPOINT_SUFFIX}/;SharedAccessSignature=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=;' custom_id = "my_custom_id" topic_name = "topic_name" sub_name = "sub_name" - servicebus_client = ServiceBusClient.from_connection_string(conn_str=servicebus_connection_str) + servicebus_client = ServiceBusClient.from_connection_string(conn_str=servicebus_connection_str, uamqp_transport=uamqp_transport) async with servicebus_client: subscription_receiver = servicebus_client.get_subscription_receiver(topic_name, sub_name, client_identifier=custom_id) assert subscription_receiver.client_identifier is not None @@ -536,22 +638,54 @@ async def test_custom_client_id_subscription_receiver(self, **kwargs): @pytest.mark.asyncio @pytest.mark.liveTest + @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer() @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @CachedServiceBusQueuePreparer(name_prefix='servicebustest') - async def 
test_connection_verify_exception_async(self, - servicebus_queue, - servicebus_namespace, - servicebus_namespace_key_name, - servicebus_namespace_primary_key, - servicebus_namespace_connection_string, + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_custom_endpoint_connection_verify_exception_async(self, + uamqp_transport, + *, + servicebus_queue=None, + servicebus_namespace=None, + servicebus_namespace_key_name=None, + servicebus_namespace_primary_key=None, + servicebus_namespace_connection_string=None, **kwargs): hostname = f"{servicebus_namespace.name}{SERVICEBUS_ENDPOINT_SUFFIX}" credential = AzureNamedKeyCredential(servicebus_namespace_key_name, servicebus_namespace_primary_key) - client = ServiceBusClient(hostname, credential, connection_verify="cacert.pem") + client = ServiceBusClient(hostname, credential, connection_verify="cacert.pem", uamqp_transport=uamqp_transport) async with client: with pytest.raises(ServiceBusError): async with client.get_queue_sender(servicebus_queue.name) as sender: await sender.send_messages(ServiceBusMessage("foo")) + # Skipping on OSX uamqp - it's raising an Authentication/TimeoutError + if not uamqp_transport or not sys.platform.startswith('darwin'): + fake_addr = "fakeaddress.com:1111" + client = ServiceBusClient( + hostname, + credential, + custom_endpoint_address=fake_addr, + retry_total=0, + uamqp_transport=uamqp_transport + ) + async with client: + with pytest.raises(ServiceBusConnectionError): + async with client.get_queue_sender(servicebus_queue.name) as sender: + await sender.send_messages(ServiceBusMessage("foo")) + + client = ServiceBusClient( + hostname, + credential, + custom_endpoint_address=fake_addr, + connection_verify="cacert.pem", + retry_total=0, + uamqp_transport=uamqp_transport, + ) + async with client: + with pytest.raises(ServiceBusError): + async with client.get_queue_sender(servicebus_queue.name) as sender: + await 
sender.send_messages(ServiceBusMessage("foo")) diff --git a/sdk/servicebus/azure-servicebus/tests/async_tests/test_sessions_async.py b/sdk/servicebus/azure-servicebus/tests/async_tests/test_sessions_async.py index 0e6aa1497a702..a670b4622898f 100644 --- a/sdk/servicebus/azure-servicebus/tests/async_tests/test_sessions_async.py +++ b/sdk/servicebus/azure-servicebus/tests/async_tests/test_sessions_async.py @@ -11,9 +11,9 @@ import pytest import time import uuid +import pickle from datetime import datetime, timedelta -from uamqp.errors import VendorLinkDetach from azure.servicebus import ( ServiceBusMessage, ServiceBusReceivedMessage, @@ -33,7 +33,7 @@ MessageAlreadySettled, AutoLockRenewTimeout ) -from devtools_testutils import AzureMgmtTestCase +from devtools_testutils import AzureMgmtRecordedTestCase from servicebus_preparer import ( CachedServiceBusNamespacePreparer, CachedServiceBusQueuePreparer, @@ -42,12 +42,14 @@ ServiceBusSubscriptionPreparer, CachedServiceBusResourceGroupPreparer ) -from utilities import get_logger, print_message +from utilities import get_logger, print_message, uamqp_transport as get_uamqp_transport, ArgPasserAsync + +uamqp_transport_params, uamqp_transport_ids = get_uamqp_transport() _logger = get_logger(logging.DEBUG) -class ServiceBusAsyncSessionTests(AzureMgmtTestCase): +class TestServiceBusAsyncSession(AzureMgmtRecordedTestCase): @pytest.mark.asyncio @pytest.mark.liveTest @@ -55,9 +57,11 @@ class ServiceBusAsyncSessionTests(AzureMgmtTestCase): @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', requires_session=True) - async def test_async_session_by_session_client_conn_str_receive_handler_peeklock(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def 
test_async_session_by_session_client_conn_str_receive_handler_peeklock(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): async with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: session_id = str(uuid.uuid4()) async with sb_client.get_queue_sender(servicebus_queue.name) as sender: @@ -101,15 +105,22 @@ async def test_async_session_by_session_client_conn_str_receive_handler_peeklock assert count == 3 + with pytest.raises(ServiceBusError): + receiver = sb_client.get_queue_receiver(servicebus_queue.name, session_id=1) + async with receiver: + pass + @pytest.mark.asyncio @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', requires_session=True, lock_duration='PT10S') - async def test_async_session_by_queue_client_conn_str_receive_handler_receiveanddelete(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_async_session_by_queue_client_conn_str_receive_handler_receiveanddelete(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): async with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: session_id = str(uuid.uuid4()) async with sb_client.get_queue_sender(servicebus_queue.name) as sender: @@ -139,15 +150,18 @@ async def 
test_async_session_by_queue_client_conn_str_receive_handler_receiveand messages.append(message) assert len(messages) == 0 + @pytest.mark.asyncio @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', requires_session=True) - async def test_async_session_by_session_client_conn_str_receive_handler_with_stop(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_async_session_by_session_client_conn_str_receive_handler_with_stop(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): async with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: session_id = str(uuid.uuid4()) async with sb_client.get_queue_sender(servicebus_queue.name) as sender: @@ -187,23 +201,28 @@ async def test_async_session_by_session_client_conn_str_receive_handler_with_sto @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', requires_session=True) - async def test_async_session_by_session_client_conn_str_receive_handler_with_no_session(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_async_session_by_session_client_conn_str_receive_handler_with_no_session(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, 
**kwargs): async with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: receiver = sb_client.get_queue_receiver(servicebus_queue.name, session_id=NEXT_AVAILABLE_SESSION, max_wait_time=10) with pytest.raises(OperationTimeoutError): await receiver._open_with_retry() + @pytest.mark.asyncio @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', requires_session=True) - async def test_async_session_by_session_client_conn_str_receive_handler_with_inactive_session(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_async_session_by_session_client_conn_str_receive_handler_with_inactive_session(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): async with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: session_id = str(uuid.uuid4()) messages = [] @@ -215,15 +234,18 @@ async def test_async_session_by_session_client_conn_str_receive_handler_with_ina assert not receiver._running assert len(messages) == 0 + @pytest.mark.asyncio @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', requires_session=True) - async def 
test_async_session_by_servicebus_client_iter_messages_with_retrieve_deferred_receiver_complete(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_async_session_by_servicebus_client_iter_messages_with_retrieve_deferred_receiver_complete(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): async with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: deferred_messages = [] session_id = str(uuid.uuid4()) @@ -253,15 +275,18 @@ async def test_async_session_by_servicebus_client_iter_messages_with_retrieve_de await receiver.renew_message_lock(message) await receiver.complete_message(message) + @pytest.mark.asyncio @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', requires_session=True) - async def test_async_session_by_servicebus_client_iter_messages_with_retrieve_deferred_receiver_deadletter(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_async_session_by_servicebus_client_iter_messages_with_retrieve_deferred_receiver_deadletter(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): async with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, 
uamqp_transport=uamqp_transport) as sb_client: deferred_messages = [] session_id = str(uuid.uuid4()) @@ -300,15 +325,18 @@ async def test_async_session_by_servicebus_client_iter_messages_with_retrieve_de await receiver.complete_message(message) assert count == 10 + @pytest.mark.asyncio @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', requires_session=True) - async def test_async_session_by_servicebus_client_iter_messages_with_retrieve_deferred_receiver_deletemode(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_async_session_by_servicebus_client_iter_messages_with_retrieve_deferred_receiver_deletemode(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): async with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: deferred_messages = [] session_id = str(uuid.uuid4()) @@ -335,15 +363,18 @@ async def test_async_session_by_servicebus_client_iter_messages_with_retrieve_de with pytest.raises(ServiceBusError): deferred = await receiver.receive_deferred_messages(deferred_messages) + @pytest.mark.asyncio @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', requires_session=True) - async def test_async_session_by_servicebus_client_iter_messages_with_retrieve_deferred_client(self, servicebus_namespace_connection_string, 
servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_async_session_by_servicebus_client_iter_messages_with_retrieve_deferred_client(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): async with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: deferred_messages = [] session_id = str(uuid.uuid4()) @@ -366,15 +397,18 @@ async def test_async_session_by_servicebus_client_iter_messages_with_retrieve_de with pytest.raises(ValueError): await receiver.complete_message(message) + @pytest.mark.asyncio @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', requires_session=True) - async def test_async_session_by_servicebus_client_fetch_next_with_retrieve_deadletter(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_async_session_by_servicebus_client_fetch_next_with_retrieve_deadletter(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): async with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: session_id = str(uuid.uuid4()) async with sb_client.get_queue_receiver(servicebus_queue.name, session_id=session_id, max_wait_time=5, prefetch_count=10) as receiver: @@ -415,9 +449,11 @@ 
async def test_async_session_by_servicebus_client_fetch_next_with_retrieve_deadl @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', requires_session=True) - async def test_async_session_by_servicebus_client_browse_messages_client(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_async_session_by_servicebus_client_browse_messages_client(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): async with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: session_id = str(uuid.uuid4()) async with sb_client.get_queue_sender(servicebus_queue.name) as sender: @@ -449,9 +485,11 @@ async def test_async_session_by_servicebus_client_browse_messages_client(self, s @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', requires_session=True) - async def test_async_session_by_servicebus_client_browse_messages_with_receiver(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_async_session_by_servicebus_client_browse_messages_with_receiver(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): async with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + 
servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: session_id = str(uuid.uuid4()) async with sb_client.get_queue_receiver(servicebus_queue.name, max_wait_time=5, session_id=session_id) as receiver: @@ -474,9 +512,11 @@ async def test_async_session_by_servicebus_client_browse_messages_with_receiver( @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', requires_session=True) - async def test_async_session_by_servicebus_client_renew_client_locks(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_async_session_by_servicebus_client_renew_client_locks(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): async with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: session_id = str(uuid.uuid4()) messages = [] @@ -510,15 +550,18 @@ async def test_async_session_by_servicebus_client_renew_client_locks(self, servi with pytest.raises(SessionLockLostError): await receiver.complete_message(messages[2]) + @pytest.mark.asyncio @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', requires_session=True, lock_duration='PT5S') - async def test_async_session_by_conn_str_receive_handler_with_autolockrenew(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", 
uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_async_session_by_conn_str_receive_handler_with_autolockrenew(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): async with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: session_id = str(uuid.uuid4()) async with sb_client.get_queue_sender(servicebus_queue.name) as sender: @@ -576,16 +619,18 @@ async def lock_lost_callback(renewable, error): await renewer.close() assert len(messages) == 2 - + @pytest.mark.asyncio @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', requires_session=True, lock_duration='PT10S') - async def test_async_session_by_conn_str_receive_handler_with_auto_autolockrenew(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_async_session_by_conn_str_receive_handler_with_auto_autolockrenew(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): async with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: session_id = str(uuid.uuid4()) async with sb_client.get_queue_sender(servicebus_queue.name) as sender: @@ -671,16 +716,17 @@ async def lock_lost_callback(renewable, error): await receiver.close() assert not renewer._renewable(receiver._session) - @pytest.mark.asyncio 
@pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', requires_session=True) - async def test_async_session_message_connection_closed(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_async_session_message_connection_closed(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): async with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: session_id = str(uuid.uuid4()) @@ -696,16 +742,17 @@ async def test_async_session_message_connection_closed(self, servicebus_namespac with pytest.raises(ValueError): await receiver.complete_message(messages[0]) - @pytest.mark.asyncio @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', requires_session=True) - async def test_async_session_message_expiry(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_async_session_message_expiry(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): async with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, 
uamqp_transport=uamqp_transport) as sb_client: session_id = str(uuid.uuid4()) @@ -742,9 +789,11 @@ async def test_async_session_message_expiry(self, servicebus_namespace_connectio @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', requires_session=True) - async def test_async_session_schedule_message(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_async_session_schedule_message(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): async with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: import uuid session_id = str(uuid.uuid4()) enqueue_time = (utc_now() + timedelta(minutes=2)).replace(microsecond=0) @@ -773,16 +822,17 @@ async def test_async_session_schedule_message(self, servicebus_namespace_connect raise Exception("Failed to receive schdeduled message.") await renewer.close() - @pytest.mark.asyncio @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', requires_session=True) - async def test_async_session_schedule_multiple_messages(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_async_session_schedule_multiple_messages(self, uamqp_transport, *, servicebus_namespace_connection_string=None, 
servicebus_queue=None, **kwargs): async with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: import uuid session_id = str(uuid.uuid4()) enqueue_time = (utc_now() + timedelta(minutes=2)).replace(microsecond=0) @@ -813,16 +863,18 @@ async def test_async_session_schedule_multiple_messages(self, servicebus_namespa else: raise Exception("Failed to receive schdeduled message.") await renewer.close() - + @pytest.mark.asyncio @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', requires_session=True) - async def test_async_session_cancel_scheduled_messages(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_async_session_cancel_scheduled_messages(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): async with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: session_id = str(uuid.uuid4()) enqueue_time = (utc_now() + timedelta(minutes=2)).replace(microsecond=0) @@ -854,10 +906,12 @@ async def test_async_session_cancel_scheduled_messages(self, servicebus_namespac @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @CachedServiceBusQueuePreparer(name_prefix='servicebustest', requires_session=True) - async def 
test_session_receiver_partially_invalid_autolockrenew_mode(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_session_receiver_partially_invalid_autolockrenew_mode(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): session_id = str(uuid.uuid4()) async with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: async with sb_client.get_queue_sender(servicebus_queue.name) as sender: await sender.send_messages(ServiceBusMessage("test_message", session_id=session_id)) @@ -872,15 +926,18 @@ async def should_not_run(*args, **kwargs): assert receiver.receive_messages() assert not failures + @pytest.mark.asyncio @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', requires_session=True) - async def test_async_session_get_set_state_with_receiver(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_async_session_get_set_state_with_receiver(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): async with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: session_id = str(uuid.uuid4()) async with 
sb_client.get_queue_sender(servicebus_queue.name) as sender: @@ -897,10 +954,24 @@ async def test_async_session_get_set_state_with_receiver(self, servicebus_namesp count += 1 state = await receiver.session.get_state() assert state == b'first_state' + assert count == 3 + + session_id = str(uuid.uuid4()) + async with sb_client.get_queue_sender(servicebus_queue.name) as sender: + for i in range(1): + message = ServiceBusMessage("Handler message no. {}".format(i), session_id=session_id) + await sender.send_messages(message) + + async with sb_client.get_queue_receiver(servicebus_queue.name, session_id=session_id, max_wait_time=10) as receiver: + assert await receiver.session.get_state(timeout=5) == None await receiver.session.set_state(None, timeout=5) + count = 0 + async for m in receiver: + assert m.session_id == session_id + count += 1 state = await receiver.session.get_state() - assert not state - assert count == 3 + assert state == None + assert count == 1 @pytest.mark.skip(reason='Requires list sessions') @pytest.mark.asyncio @@ -909,9 +980,11 @@ async def test_async_session_get_set_state_with_receiver(self, servicebus_namesp @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', requires_session=True) - async def test_async_session_by_servicebus_client_list_sessions_with_receiver(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_async_session_by_servicebus_client_list_sessions_with_receiver(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): async with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, 
logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: sessions = [] start_time = utc_now() @@ -939,9 +1012,11 @@ async def test_async_session_by_servicebus_client_list_sessions_with_receiver(se @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', requires_session=True) - async def test_async_session_by_servicebus_client_list_sessions_with_client(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_async_session_by_servicebus_client_list_sessions_with_client(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): async with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: sessions = [] start_time = utc_now() @@ -968,7 +1043,9 @@ async def test_async_session_by_servicebus_client_list_sessions_with_client(self @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', requires_session=True) - async def test_async_session_by_servicebus_client_session_pool(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_async_session_by_servicebus_client_session_pool(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): messages = [] errors = [] @@ -989,7 +1066,7 @@ async def message_processing(sb_client): concurrent_receivers = 
5 sessions = [str(uuid.uuid4()) for i in range(concurrent_receivers)] async with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False, retry_total=1) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, retry_total=1, uamqp_transport=uamqp_transport) as sb_client: for session_id in sessions: async with sb_client.get_queue_sender(servicebus_queue.name) as sender: @@ -1001,6 +1078,7 @@ async def message_processing(sb_client): assert not errors assert len(messages) == 100 + @pytest.mark.asyncio @pytest.mark.liveTest @pytest.mark.live_test_only @@ -1008,10 +1086,13 @@ async def message_processing(sb_client): @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusTopicPreparer(name_prefix='servicebustest') @ServiceBusSubscriptionPreparer(name_prefix='servicebustest', requires_session=True) - async def test_async_session_basic_topic_subscription_send_and_receive(self, servicebus_namespace_connection_string, servicebus_topic, servicebus_subscription, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_async_session_basic_topic_subscription_send_and_receive(self, uamqp_transport, *, servicebus_namespace_connection_string, servicebus_topic, servicebus_subscription, **kwargs): async with ServiceBusClient.from_connection_string( servicebus_namespace_connection_string, - logging_enable=False + logging_enable=False, + uamqp_transport=uamqp_transport ) as sb_client: async with sb_client.get_topic_sender(topic_name=servicebus_topic.name) as sender: message = ServiceBusMessage(b"Sample topic message", session_id='test_session') @@ -1036,10 +1117,12 @@ async def test_async_session_basic_topic_subscription_send_and_receive(self, ser @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') 
@ServiceBusQueuePreparer(name_prefix='servicebustest', requires_session=True) - async def test_async_session_connection_failure_is_idempotent(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_async_session_connection_failure_is_idempotent(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): #Technically this validates for all senders/receivers, not just session, but since it uses session to generate a recoverable failure, putting it in here. async with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False, retry_total=1) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, retry_total=1, uamqp_transport=uamqp_transport) as sb_client: # First let's just try the naive failure cases. receiver = sb_client.get_queue_receiver("THIS_IS_WRONG_ON_PURPOSE") @@ -1075,9 +1158,11 @@ async def test_async_session_connection_failure_is_idempotent(self, servicebus_n @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @CachedServiceBusQueuePreparer(name_prefix='servicebustest', requires_session=True) - async def test_async_session_non_session_send_to_session_queue_should_fail(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_async_session_non_session_send_to_session_queue_should_fail(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): async with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, 
logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: async with sb_client.get_queue_sender(servicebus_queue.name) as sender: with pytest.raises(ServiceBusError): diff --git a/sdk/servicebus/azure-servicebus/tests/async_tests/test_subscriptions_async.py b/sdk/servicebus/azure-servicebus/tests/async_tests/test_subscriptions_async.py index 0919d457d4c68..d3dcbc8956d00 100644 --- a/sdk/servicebus/azure-servicebus/tests/async_tests/test_subscriptions_async.py +++ b/sdk/servicebus/azure-servicebus/tests/async_tests/test_subscriptions_async.py @@ -17,7 +17,7 @@ from azure.servicebus.exceptions import ServiceBusError from azure.servicebus._common.constants import ServiceBusSubQueue -from devtools_testutils import AzureMgmtTestCase, RandomNameResourceGroupPreparer +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer from servicebus_preparer import ( CachedServiceBusNamespacePreparer, CachedServiceBusTopicPreparer, @@ -27,12 +27,15 @@ CachedServiceBusResourceGroupPreparer, SERVICEBUS_ENDPOINT_SUFFIX ) -from utilities import get_logger, print_message +from utilities import get_logger, print_message, uamqp_transport as get_uamqp_transport, ArgPasserAsync + +uamqp_transport_params, uamqp_transport_ids = get_uamqp_transport() _logger = get_logger(logging.DEBUG) -class ServiceBusSubscriptionAsyncTests(AzureMgmtTestCase): +class TestServiceBusSubscriptionAsync(AzureMgmtRecordedTestCase): + @pytest.mark.asyncio @pytest.mark.liveTest @pytest.mark.live_test_only @@ -40,11 +43,14 @@ class ServiceBusSubscriptionAsyncTests(AzureMgmtTestCase): @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusTopicPreparer(name_prefix='servicebustest') @ServiceBusSubscriptionPreparer(name_prefix='servicebustest') - async def test_subscription_by_subscription_client_conn_str_receive_basic(self, servicebus_namespace_connection_string, servicebus_topic, servicebus_subscription, **kwargs): + 
@pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_subscription_by_subscription_client_conn_str_receive_basic(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_topic=None, servicebus_subscription=None, **kwargs): async with ServiceBusClient.from_connection_string( servicebus_namespace_connection_string, - logging_enable=False + logging_enable=False, + uamqp_transport=uamqp_transport ) as sb_client: async with sb_client.get_topic_sender(topic_name=servicebus_topic.name) as sender: message = ServiceBusMessage(b"Sample topic message") @@ -66,15 +72,13 @@ async def test_subscription_by_subscription_client_conn_str_receive_basic(self, with pytest.raises(ValueError): await receiver.receive_messages(max_wait_time=-1) - with pytest.raises(ValueError): - await receiver._get_streaming_message_iter(max_wait_time=0) - count = 0 async for message in receiver: count += 1 await receiver.complete_message(message) assert count == 1 + @pytest.mark.asyncio @pytest.mark.liveTest @pytest.mark.live_test_only @@ -82,7 +86,9 @@ async def test_subscription_by_subscription_client_conn_str_receive_basic(self, @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusTopicPreparer(name_prefix='servicebustest') @ServiceBusSubscriptionPreparer(name_prefix='servicebustest') - async def test_subscription_by_sas_token_credential_conn_str_send_basic(self, servicebus_namespace, servicebus_namespace_key_name, servicebus_namespace_primary_key, servicebus_topic, servicebus_subscription, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_subscription_by_sas_token_credential_conn_str_send_basic(self, uamqp_transport, *, servicebus_namespace=None, servicebus_namespace_key_name=None, servicebus_namespace_primary_key=None, servicebus_topic=None, servicebus_subscription=None, **kwargs): 
fully_qualified_namespace = f"{servicebus_namespace.name}{SERVICEBUS_ENDPOINT_SUFFIX}" async with ServiceBusClient( fully_qualified_namespace=fully_qualified_namespace, @@ -90,7 +96,8 @@ async def test_subscription_by_sas_token_credential_conn_str_send_basic(self, se policy=servicebus_namespace_key_name, key=servicebus_namespace_primary_key ), - logging_enable=False + logging_enable=False, + uamqp_transport=uamqp_transport ) as sb_client: async with sb_client.get_topic_sender(topic_name=servicebus_topic.name) as sender: @@ -108,6 +115,7 @@ async def test_subscription_by_sas_token_credential_conn_str_send_basic(self, se await receiver.complete_message(message) assert count == 1 + @pytest.mark.asyncio @pytest.mark.liveTest @pytest.mark.live_test_only @@ -115,11 +123,13 @@ async def test_subscription_by_sas_token_credential_conn_str_send_basic(self, se @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusTopicPreparer(name_prefix='servicebustest') @ServiceBusSubscriptionPreparer(name_prefix='servicebustest') - async def test_topic_by_servicebus_client_receive_batch_with_deadletter(self, servicebus_namespace_connection_string, servicebus_topic, servicebus_subscription, **kwargs): - + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_topic_by_servicebus_client_receive_batch_with_deadletter(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_topic=None, servicebus_subscription=None, **kwargs): async with ServiceBusClient.from_connection_string( servicebus_namespace_connection_string, - logging_enable=False + logging_enable=False, + uamqp_transport=uamqp_transport ) as sb_client: async with sb_client.get_subscription_receiver( diff --git a/sdk/servicebus/azure-servicebus/tests/async_tests/test_topic_async.py b/sdk/servicebus/azure-servicebus/tests/async_tests/test_topic_async.py index 489096335d26f..6e0aab8b31542 100644 --- 
a/sdk/servicebus/azure-servicebus/tests/async_tests/test_topic_async.py +++ b/sdk/servicebus/azure-servicebus/tests/async_tests/test_topic_async.py @@ -12,7 +12,7 @@ import time from datetime import datetime, timedelta -from devtools_testutils import AzureMgmtTestCase, RandomNameResourceGroupPreparer +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer from azure.servicebus.aio import ServiceBusClient from azure.servicebus.aio._base_handler_async import ServiceBusSharedKeyCredential @@ -25,23 +25,28 @@ CachedServiceBusResourceGroupPreparer, SERVICEBUS_ENDPOINT_SUFFIX ) -from utilities import get_logger, print_message +from utilities import get_logger, print_message, uamqp_transport as get_uamqp_transport, ArgPasserAsync + +uamqp_transport_params, uamqp_transport_ids = get_uamqp_transport() _logger = get_logger(logging.DEBUG) -class ServiceBusTopicsAsyncTests(AzureMgmtTestCase): +class TestServiceBusTopicsAsync(AzureMgmtRecordedTestCase): @pytest.mark.asyncio @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @CachedServiceBusTopicPreparer(name_prefix='servicebustest') - async def test_topic_by_servicebus_client_conn_str_send_basic(self, servicebus_namespace_connection_string, servicebus_topic, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_topic_by_servicebus_client_conn_str_send_basic(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_topic=None, **kwargs): async with ServiceBusClient.from_connection_string( servicebus_namespace_connection_string, - logging_enable=False + logging_enable=False, + uamqp_transport=uamqp_transport ) as sb_client: async with sb_client.get_topic_sender(servicebus_topic.name) as sender: message = ServiceBusMessage(b"Sample topic message") @@ 
-53,7 +58,9 @@ async def test_topic_by_servicebus_client_conn_str_send_basic(self, servicebus_n @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @CachedServiceBusTopicPreparer(name_prefix='servicebustest') - async def test_topic_by_sas_token_credential_conn_str_send_basic(self, servicebus_namespace, servicebus_namespace_key_name, servicebus_namespace_primary_key, servicebus_topic, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasserAsync() + async def test_topic_by_sas_token_credential_conn_str_send_basic(self, uamqp_transport, *, servicebus_namespace=None, servicebus_namespace_key_name=None, servicebus_namespace_primary_key=None, servicebus_topic=None, **kwargs): fully_qualified_namespace = f"{servicebus_namespace.name}{SERVICEBUS_ENDPOINT_SUFFIX}" async with ServiceBusClient( fully_qualified_namespace=fully_qualified_namespace, @@ -61,7 +68,8 @@ async def test_topic_by_sas_token_credential_conn_str_send_basic(self, servicebu policy=servicebus_namespace_key_name, key=servicebus_namespace_primary_key ), - logging_enable=False + logging_enable=False, + uamqp_transport=uamqp_transport ) as sb_client: async with sb_client.get_topic_sender(servicebus_topic.name) as sender: message = ServiceBusMessage(b"Sample topic message") diff --git a/sdk/servicebus/azure-servicebus/tests/livetest/test_errors.py b/sdk/servicebus/azure-servicebus/tests/livetest/test_errors.py deleted file mode 100644 index 083e62b4a3107..0000000000000 --- a/sdk/servicebus/azure-servicebus/tests/livetest/test_errors.py +++ /dev/null @@ -1,44 +0,0 @@ -import logging - -from uamqp import errors as AMQPErrors, constants as AMQPConstants -from azure.servicebus.exceptions import ( - _create_servicebus_exception, - ServiceBusConnectionError, - ServiceBusError -) - - -def test_link_idle_timeout(): - logger = logging.getLogger("testlogger") - amqp_error = 
AMQPErrors.LinkDetach(AMQPConstants.ErrorCodes.LinkDetachForced, description="Details: AmqpMessageConsumer.IdleTimerExpired: Idle timeout: 00:10:00.") - sb_error = _create_servicebus_exception(logger, amqp_error) - assert isinstance(sb_error, ServiceBusConnectionError) - assert sb_error._retryable - assert sb_error._shutdown_handler - - -def test_unknown_connection_error(): - logger = logging.getLogger("testlogger") - amqp_error = AMQPErrors.AMQPConnectionError(AMQPConstants.ErrorCodes.UnknownError) - sb_error = _create_servicebus_exception(logger, amqp_error) - assert isinstance(sb_error,ServiceBusConnectionError) - assert sb_error._retryable - assert sb_error._shutdown_handler - - amqp_error = AMQPErrors.AMQPError(AMQPConstants.ErrorCodes.UnknownError) - sb_error = _create_servicebus_exception(logger, amqp_error) - assert not isinstance(sb_error,ServiceBusConnectionError) - assert isinstance(sb_error,ServiceBusError) - assert not sb_error._retryable - assert sb_error._shutdown_handler - -def test_internal_server_error(): - logger = logging.getLogger("testlogger") - amqp_error = AMQPErrors.LinkDetach( - description="The service was unable to process the request; please retry the operation.", - condition=AMQPConstants.ErrorCodes.InternalServerError - ) - sb_error = _create_servicebus_exception(logger, amqp_error) - assert isinstance(sb_error, ServiceBusError) - assert sb_error._retryable - assert sb_error._shutdown_handler diff --git a/sdk/servicebus/azure-servicebus/tests/perf_tests/T1_legacy_tests/__init__.py b/sdk/servicebus/azure-servicebus/tests/perf_tests/T1_legacy_tests/__init__.py deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/sdk/servicebus/azure-servicebus/tests/perf_tests/T1_legacy_tests/_test_base.py b/sdk/servicebus/azure-servicebus/tests/perf_tests/T1_legacy_tests/_test_base.py deleted file mode 100644 index 654b7b40b532c..0000000000000 --- a/sdk/servicebus/azure-servicebus/tests/perf_tests/T1_legacy_tests/_test_base.py +++ 
/dev/null @@ -1,153 +0,0 @@ -# -------------------------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for license information. -# -------------------------------------------------------------------------------------------- - -import uuid -from urllib.parse import urlparse - -from azure_devtools.perfstress_tests import PerfStressTest, get_random_bytes - -from azure.servicebus import ServiceBusClient, ReceiveSettleMode, Message -from azure.servicebus.aio import ServiceBusClient as AsyncServiceBusClient -from azure.servicebus.control_client import ServiceBusService -from azure.servicebus.control_client.models import Queue - -MAX_QUEUE_SIZE = 40960 - - -def parse_connection_string(conn_str): - conn_settings = [s.split("=", 1) for s in conn_str.split(";")] - conn_settings = dict(conn_settings) - shared_access_key = conn_settings.get('SharedAccessKey') - shared_access_key_name = conn_settings.get('SharedAccessKeyName') - endpoint = conn_settings.get('Endpoint') - parsed = urlparse(endpoint.rstrip('/')) - namespace = parsed.netloc.strip().split('.')[0] - return { - 'namespace': namespace, - 'endpoint': endpoint, - 'entity_path': conn_settings.get('EntityPath'), - 'shared_access_key_name': shared_access_key_name, - 'shared_access_key': shared_access_key - } - - -class _ServiceTest(PerfStressTest): - service_client = None - async_service_client = None - - def __init__(self, arguments): - super().__init__(arguments) - - connection_string = self.get_from_env("AZURE_SERVICEBUS_CONNECTION_STRING") - if self.args.no_client_share: - self.service_client = ServiceBusClient.from_connection_string(connection_string) - self.async_service_client = AsyncServiceBusClient.from_connection_string(connection_string) - else: - if not _ServiceTest.service_client: - _ServiceTest.service_client = 
ServiceBusClient.from_connection_string(connection_string) - _ServiceTest.async_service_client = AsyncServiceBusClient.from_connection_string(connection_string) - self.service_client = _ServiceTest.service_client - self.async_service_client =_ServiceTest.async_service_client - - @staticmethod - def add_arguments(parser): - super(_ServiceTest, _ServiceTest).add_arguments(parser) - parser.add_argument('--message-size', nargs='?', type=int, help='Size of a single message. Defaults to 100 bytes', default=100) - parser.add_argument('--no-client-share', action='store_true', help='Create one ServiceClient per test instance. Default is to share a single ServiceClient.', default=False) - parser.add_argument('--num-messages', nargs='?', type=int, help='Number of messages to send or receive. Defaults to 100', default=100) - - -class _QueueTest(_ServiceTest): - queue_name = "perfstress-" + str(uuid.uuid4()) - queue_client = None - async_queue_client = None - - def __init__(self, arguments): - super().__init__(arguments) - connection_string = self.get_from_env("AZURE_SERVICEBUS_CONNECTION_STRING") - connection_props = parse_connection_string(connection_string) - self.mgmt_client = ServiceBusService( - service_namespace=connection_props['namespace'], - shared_access_key_name=connection_props['shared_access_key_name'], - shared_access_key_value=connection_props['shared_access_key']) - - async def global_setup(self): - await super().global_setup() - queue = Queue(max_size_in_megabytes=MAX_QUEUE_SIZE) - self.mgmt_client.create_queue(self.queue_name, queue=queue) - - async def setup(self): - await super().setup() - # In T1, these operations check for the existance of the queue - # so must be created during setup, rather than in the constructor. 
- self.queue_client = self.service_client.get_queue(self.queue_name) - self.async_queue_client = self.async_service_client.get_queue(self.queue_name) - - async def global_cleanup(self): - self.mgmt_client.delete_queue(self.queue_name) - await super().global_cleanup() - - -class _SendTest(_QueueTest): - sender = None - async_sender = None - - async def setup(self): - await super().setup() - self.sender = self.queue_client.get_sender() - self.async_sender = self.async_queue_client.get_sender() - self.sender.open() - await self.async_sender.open() - - async def close(self): - self.sender.close() - await self.async_sender.close() - await super().close() - - -class _ReceiveTest(_QueueTest): - receiver = None - async_receiver = None - - async def global_setup(self): - await super().global_setup() - await self._preload_queue() - - async def setup(self): - await super().setup() - mode = ReceiveSettleMode.PeekLock if self.args.peeklock else ReceiveSettleMode.ReceiveAndDelete - self.receiver = self.queue_client.get_receiver( - mode=mode, - prefetch=self.args.num_messages, - idle_timeout=self.args.max_wait_time) - self.async_receiver = self.async_queue_client.get_receiver( - mode=mode, - prefetch=self.args.num_messages, - idle_timeout=self.args.max_wait_time) - self.receiver.open() - await self.async_receiver.open() - - async def _preload_queue(self): - data = get_random_bytes(self.args.message_size) - async_queue_client = self.async_service_client.get_queue(self.queue_name) - async with async_queue_client.get_sender() as sender: - for i in range(self.args.preload): - sender.queue_message(Message(data)) - if i % 1000 == 0: - print("Loaded {} messages".format(i)) - await sender.send_pending_messages() - await sender.send_pending_messages() - - async def close(self): - self.receiver.close() - await self.async_receiver.close() - await super().close() - - @staticmethod - def add_arguments(parser): - super(_ReceiveTest, _ReceiveTest).add_arguments(parser) - 
parser.add_argument('--peeklock', action='store_true', help='Receive using PeekLock mode and message settlement.', default=False) - parser.add_argument('--max-wait-time', nargs='?', type=int, help='Max time to wait for messages before closing. Defaults to 0.', default=0) - parser.add_argument('--preload', nargs='?', type=int, help='Number of messages to preload. Default is 10000.', default=10000) diff --git a/sdk/servicebus/azure-servicebus/tests/perf_tests/T1_legacy_tests/receive_message_batch.py b/sdk/servicebus/azure-servicebus/tests/perf_tests/T1_legacy_tests/receive_message_batch.py deleted file mode 100644 index 171786f81144b..0000000000000 --- a/sdk/servicebus/azure-servicebus/tests/perf_tests/T1_legacy_tests/receive_message_batch.py +++ /dev/null @@ -1,27 +0,0 @@ -# -------------------------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for license information. 
-# -------------------------------------------------------------------------------------------- - -import asyncio - -from ._test_base import _ReceiveTest - - -class LegacyReceiveMessageBatchTest(_ReceiveTest): - def run_sync(self): - count = 0 - while count < self.args.num_messages: - batch = self.receiver.fetch_next(max_batch_size=self.args.num_messages - count) - if self.args.peeklock: - for msg in batch: - msg.complete() - count += len(batch) - - async def run_async(self): - count = 0 - while count < self.args.num_messages: - batch = await self.async_receiver.fetch_next(max_batch_size=self.args.num_messages - count) - if self.args.peeklock: - await asyncio.gather(*[m.complete() for m in batch]) - count += len(batch) diff --git a/sdk/servicebus/azure-servicebus/tests/perf_tests/T1_legacy_tests/send_message.py b/sdk/servicebus/azure-servicebus/tests/perf_tests/T1_legacy_tests/send_message.py deleted file mode 100644 index 0e4c9fb79642a..0000000000000 --- a/sdk/servicebus/azure-servicebus/tests/perf_tests/T1_legacy_tests/send_message.py +++ /dev/null @@ -1,25 +0,0 @@ -# -------------------------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for license information. 
-# -------------------------------------------------------------------------------------------- - -from ._test_base import _SendTest - -from azure_devtools.perfstress_tests import get_random_bytes - -from azure.servicebus import Message -from azure.servicebus.aio import Message as AsyncMessage - - -class LegacySendMessageTest(_SendTest): - def __init__(self, arguments): - super().__init__(arguments) - self.data = get_random_bytes(self.args.message_size) - - def run_sync(self): - message = Message(self.data) - self.sender.send(message) - - async def run_async(self): - message = AsyncMessage(self.data) - await self.async_sender.send(message) diff --git a/sdk/servicebus/azure-servicebus/tests/perf_tests/T1_legacy_tests/send_message_batch.py b/sdk/servicebus/azure-servicebus/tests/perf_tests/T1_legacy_tests/send_message_batch.py deleted file mode 100644 index fc73cf0bb2d7f..0000000000000 --- a/sdk/servicebus/azure-servicebus/tests/perf_tests/T1_legacy_tests/send_message_batch.py +++ /dev/null @@ -1,26 +0,0 @@ -# -------------------------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for license information. 
-# -------------------------------------------------------------------------------------------- - -from ._test_base import _SendTest - -from azure_devtools.perfstress_tests import get_random_bytes - -from azure.servicebus import BatchMessage - - -class LegacySendMessageBatchTest(_SendTest): - def __init__(self, arguments): - super().__init__(arguments) - self.data = get_random_bytes(self.args.message_size) - - def run_sync(self): - messages = (self.data for _ in range(self.args.num_messages)) - batch = BatchMessage(messages) - self.sender.send(batch) - - async def run_async(self): - messages = (self.data for _ in range(self.args.num_messages)) - batch = BatchMessage(messages) - await self.async_sender.send(batch) diff --git a/sdk/servicebus/azure-servicebus/tests/perf_tests/T1_legacy_tests/t1_test_requirements.txt b/sdk/servicebus/azure-servicebus/tests/perf_tests/T1_legacy_tests/t1_test_requirements.txt deleted file mode 100644 index e07a582d0f192..0000000000000 --- a/sdk/servicebus/azure-servicebus/tests/perf_tests/T1_legacy_tests/t1_test_requirements.txt +++ /dev/null @@ -1 +0,0 @@ -azure-servicebus>=0.5,<1.0.0 diff --git a/sdk/servicebus/azure-servicebus/tests/perf_tests/_test_base.py b/sdk/servicebus/azure-servicebus/tests/perf_tests/_test_base.py index f0ac79bf020ed..a822c0e691a88 100644 --- a/sdk/servicebus/azure-servicebus/tests/perf_tests/_test_base.py +++ b/sdk/servicebus/azure-servicebus/tests/perf_tests/_test_base.py @@ -5,124 +5,272 @@ import uuid -from azure_devtools.perfstress_tests import PerfStressTest, get_random_bytes +from azure_devtools.perfstress_tests import PerfStressTest, get_random_bytes, BatchPerfTest -from azure.servicebus import ServiceBusClient, ServiceBusReceiveMode, ServiceBusMessage +from azure.servicebus import ServiceBusClient, ServiceBusReceiveMode, ServiceBusMessage, TransportType from azure.servicebus.aio import ServiceBusClient as AsyncServiceBusClient from azure.servicebus.aio.management import ServiceBusAdministrationClient 
-MAX_QUEUE_SIZE = 40960 +class _ReceiveTest(): + def setup_servicebus_clients(self, transport_type, peeklock, num_messages, max_wait_time, uamqp_transport) -> None: + self.connection_string=self.get_from_env("AZURE_SERVICEBUS_CONNECTION_STRING") + transport_type=TransportType.AmqpOverWebsocket if transport_type == 1 else TransportType.Amqp + mode=ServiceBusReceiveMode.PEEK_LOCK if peeklock else ServiceBusReceiveMode.RECEIVE_AND_DELETE + uamqp_transport = uamqp_transport if uamqp_transport else False + self.servicebus_client=ServiceBusClient.from_connection_string( + self.connection_string, + receive_mode=mode, + prefetch_count=num_messages, + max_wait_time=max_wait_time or None, + transport_type=transport_type, + uamqp_transport=uamqp_transport, + ) + self.async_servicebus_client=AsyncServiceBusClient.from_connection_string( + self.connection_string, + receive_mode=mode, + prefetch_count=num_messages, + max_wait_time=max_wait_time or None, + transport_type=transport_type, + uamqp_transport=uamqp_transport, + ) + async def close_clients(self) -> None: + self.servicebus_client.close() + await self.async_servicebus_client.close() + await super().close() + async def _preload_topic(self) -> None: + data=get_random_bytes(self.args.message_size) -class _ServiceTest(PerfStressTest): - service_client = None - async_service_client = None + current_topic_message_count = 0 - def __init__(self, arguments): - super().__init__(arguments) + async with ServiceBusAdministrationClient.from_connection_string(self.connection_string) as admin_client: + topic_properties = await admin_client.get_topic_runtime_properties(self.topic_name) + current_topic_message_count = topic_properties.scheduled_message_count + - connection_string = self.get_from_env("AZURE_SERVICEBUS_CONNECTION_STRING") - if self.args.no_client_share: - self.service_client = ServiceBusClient.from_connection_string(connection_string) - self.async_service_client = AsyncServiceBusClient.from_connection_string(connection_string) -
else: - if not _ServiceTest.service_client: - _ServiceTest.service_client = ServiceBusClient.from_connection_string(connection_string) - _ServiceTest.async_service_client = AsyncServiceBusClient.from_connection_string(connection_string) - self.service_client = _ServiceTest.service_client - self.async_service_client =_ServiceTest.async_service_client + print(f"The current topic {self.topic_name} has {current_topic_message_count} messages") + async with self.async_servicebus_client.get_topic_sender(self.topic_name) as sender: + batch = await sender.create_message_batch() - async def close(self): - self.service_client.close() - await self.async_service_client.close() - await super().close() + for i in range(current_topic_message_count, self.args.preload): + try: + batch.add_message(ServiceBusMessage(data)) + except ValueError: + await sender.send_messages(batch) + print(f"Loaded {i} messages") + batch = await sender.create_message_batch() + batch.add_message(ServiceBusMessage(data)) - @staticmethod - def add_arguments(parser): - super(_ServiceTest, _ServiceTest).add_arguments(parser) - parser.add_argument('--message-size', nargs='?', type=int, help='Size of a single message. Defaults to 100 bytes', default=100) - parser.add_argument('--no-client-share', action='store_true', help='Create one ServiceClient per test instance. Default is to share a single ServiceClient.', default=False) - parser.add_argument('--num-messages', nargs='?', type=int, help='Number of messages to send or receive. 
Defaults to 100', default=100) + if len(batch): + await sender.send_messages(batch) + + async def _preload_queue(self) -> None: + data=get_random_bytes(self.args.message_size) + current_queue_message_count = 0 -class _QueueTest(_ServiceTest): - queue_name = "perfstress-" + str(uuid.uuid4()) + async with ServiceBusAdministrationClient.from_connection_string(self.connection_string) as admin_client: + queue_properties = await admin_client.get_queue_runtime_properties(self.queue_name) + current_queue_message_count = queue_properties.active_message_count + - def __init__(self, arguments): - super().__init__(arguments) - connection_string = self.get_from_env("AZURE_SERVICEBUS_CONNECTION_STRING") - self.async_mgmt_client = ServiceBusAdministrationClient.from_connection_string(connection_string) + print(f"The current queue {self.queue_name} has {current_queue_message_count} messages") - async def global_setup(self): - await super().global_setup() - await self.async_mgmt_client.create_queue(self.queue_name, max_size_in_megabytes=MAX_QUEUE_SIZE) + async with self.async_servicebus_client.get_queue_sender(self.queue_name) as sender: + batch = await sender.create_message_batch() - async def global_cleanup(self): - await self.async_mgmt_client.delete_queue(self.queue_name) - await super().global_cleanup() + for i in range(current_queue_message_count, self.args.preload): + try: + batch.add_message(ServiceBusMessage(data)) + except ValueError: + await sender.send_messages(batch) + print(f"Loaded {i} messages") + batch = await sender.create_message_batch() + batch.add_message(ServiceBusMessage(data)) + + if len(batch): + await sender.send_messages(batch) + + @staticmethod + def add_arguments(parser) -> None: + parser.add_argument('--message-size', nargs='?', type=int, help='Size of a single message. Defaults to 100 bytes', default=100) + parser.add_argument('--num-messages', nargs='?', type=int, help='Maximum number of messages to receive. 
Defaults to 100', default=100) + parser.add_argument('--peeklock', action='store_true', help='Receive using PeekLock mode and message settlement.', default=False) + parser.add_argument('--uamqp-transport', action="store_true", help="Switch to use uamqp transport. Default is False (pyamqp).", default=False) + parser.add_argument('--transport-type', nargs='?', type=int, help="Use Amqp (0) or Websocket (1) transport type. Default is Amqp.", default=0) + parser.add_argument('--max-wait-time', nargs='?', type=int, help='Max time to wait for messages before closing. Defaults to 0.', default=0) + parser.add_argument('--preload', nargs='?', type=int, help='Number of messages to preload. Default is 10000.', default=10000) - async def close(self): - await self.async_mgmt_client.close() +class _QueueReceiveTest(_ReceiveTest, PerfStressTest): + def __init__(self, arguments) -> None: + super().__init__(arguments) + self.setup_servicebus_clients( + arguments.transport_type, + arguments.peeklock, + arguments.num_messages, + arguments.max_wait_time, + arguments.uamqp_transport + ) + self.queue_name=self.get_from_env('AZURE_SERVICEBUS_QUEUE_NAME') + + self.receiver=self.servicebus_client.get_queue_receiver(self.queue_name) + self.async_receiver=self.async_servicebus_client.get_queue_receiver(self.queue_name) + + async def global_setup(self) -> None: + await super().global_setup() + await self._preload_queue() + + async def close(self) -> None: + self.receiver.close() + await self.async_receiver.close() + await self.close_clients() await super().close() + + -class _SendTest(_QueueTest): - def __init__(self, arguments): +class _SubscriptionReceiveTest(_ReceiveTest, PerfStressTest): + def __init__(self, arguments) -> None: super().__init__(arguments) - connection_string = self.get_from_env("AZURE_SERVICEBUS_CONNECTION_STRING") - self.async_mgmt_client = ServiceBusAdministrationClient.from_connection_string(connection_string) - self.sender = 
self.service_client.get_queue_sender(self.queue_name) - self.async_sender = self.async_service_client.get_queue_sender(self.queue_name) + self.setup_servicebus_clients( + arguments.transport_type, + arguments.peeklock, + arguments.num_messages, + arguments.max_wait_time, + arguments.uamqp_transport + ) + self.topic_name=self.get_from_env('AZURE_SERVICEBUS_TOPIC_NAME') + self.subscription_name=self.get_from_env('AZURE_SERVICEBUS_SUBSCRIPTION_NAME') + + self.receiver=self.servicebus_client.get_subscription_receiver(topic_name=self.topic_name, subscription_name=self.subscription_name) + self.async_receiver=self.async_servicebus_client.get_subscription_receiver(topic_name=self.topic_name, subscription_name=self.subscription_name) - async def close(self): - self.sender.close() - await self.async_sender.close() + async def global_setup(self) -> None: + await super().global_setup() + await self._preload_topic() + + async def close(self) -> None: + self.receiver.close() + await self.async_receiver.close() + await self.close_clients() await super().close() -class _ReceiveTest(_QueueTest): - def __init__(self, arguments): +class _QueueReceiveBatchTest(_ReceiveTest, BatchPerfTest): + def __init__(self, arguments) -> None: super().__init__(arguments) - mode = ServiceBusReceiveMode.PEEK_LOCK if self.args.peeklock else ServiceBusReceiveMode.RECEIVE_AND_DELETE - self.receiver = self.service_client.get_queue_receiver( - queue_name=self.queue_name, - receive_mode=mode, - prefetch_count=self.args.num_messages, - max_wait_time=self.args.max_wait_time or None) - self.async_receiver = self.async_service_client.get_queue_receiver( - queue_name=self.queue_name, - receive_mode=mode, - prefetch_count=self.args.num_messages, - max_wait_time=self.args.max_wait_time or None) + self.setup_servicebus_clients( + arguments.transport_type, + arguments.peeklock, + arguments.num_messages, + arguments.max_wait_time, + arguments.uamqp_transport + ) + 
self.queue_name=self.get_from_env('AZURE_SERVICEBUS_QUEUE_NAME') + self.receiver=self.servicebus_client.get_queue_receiver(self.queue_name) + self.async_receiver=self.async_servicebus_client.get_queue_receiver(self.queue_name) + + async def global_setup(self) -> None: + await super().global_setup() + await self._preload_queue() + + async def close(self) -> None: + self.receiver.close() + await self.async_receiver.close() + await self.close_clients() + await super().close() + - async def _preload_queue(self): - data = get_random_bytes(self.args.message_size) - async with self.async_service_client.get_queue_sender(self.queue_name) as sender: - batch = await sender.create_message_batch() - for i in range(self.args.preload): - try: - batch.add_message(ServiceBusMessage(data)) - except ValueError: - # Batch full - await sender.send_messages(batch) - print("Loaded {} messages".format(i)) - batch = await sender.create_message_batch() - batch.add_message(ServiceBusMessage(data)) - await sender.send_messages(batch) +class _SubscriptionReceiveBatchTest(_ReceiveTest, BatchPerfTest): + def __init__(self, arguments) -> None: + super().__init__(arguments) + self.setup_servicebus_clients( + arguments.transport_type, + arguments.peeklock, + arguments.num_messages, + arguments.max_wait_time, + arguments.uamqp_transport + ) + + self.topic_name=self.get_from_env('AZURE_SERVICEBUS_TOPIC_NAME') + self.subscription_name=self.get_from_env('AZURE_SERVICEBUS_SUBSCRIPTION_NAME') - async def global_setup(self): + self.receiver=self.servicebus_client.get_subscription_receiver(topic_name=self.topic_name, subscription_name=self.subscription_name) + self.async_receiver=self.async_servicebus_client.get_subscription_receiver(topic_name=self.topic_name, subscription_name=self.subscription_name) + + async def global_setup(self) -> None: await super().global_setup() - await self._preload_queue() + await self._preload_topic() - async def close(self): + async def close(self) -> None: 
self.receiver.close() await self.async_receiver.close() + await self.close_clients() await super().close() +class _SendTest(BatchPerfTest): + def __init__(self, arguments) -> None: + super().__init__(arguments) + transport_type=TransportType.AmqpOverWebsocket if arguments.transport_type == 1 else TransportType.Amqp + + self.connection_string=self.get_from_env("AZURE_SERVICEBUS_CONNECTION_STRING") + self.service_client=ServiceBusClient.from_connection_string( + self.connection_string, + transport_type=transport_type, + uamqp_transport=arguments.uamqp_transport, + ) + self.async_service_client=AsyncServiceBusClient.from_connection_string( + self.connection_string, + transport_type=transport_type, + uamqp_transport=arguments.uamqp_transport, + ) + async def close(self) -> None: + self.service_client.close() + await self.async_service_client.close() + await super().close() + @staticmethod - def add_arguments(parser): - super(_ReceiveTest, _ReceiveTest).add_arguments(parser) - parser.add_argument('--peeklock', action='store_true', help='Receive using PeekLock mode and message settlement.', default=False) - parser.add_argument('--max-wait-time', nargs='?', type=int, help='Max time to wait for messages before closing. Defaults to 0.', default=0) - parser.add_argument('--preload', nargs='?', type=int, help='Number of messages to preload. Default is 10000.', default=10000) + def add_arguments(parser) -> None: + parser.add_argument('--message-size', nargs='?', type=int, help='Size of a single message. Defaults to 100 bytes', default=100) + parser.add_argument('--batch-size', nargs='?', type=int, help='Size of a single batch message. Defaults to 100 messages', default=100) + parser.add_argument('--uamqp-transport', action="store_true", help="Switch to use uamqp transport. Default is False (pyamqp).", default=False) + parser.add_argument('--transport-type', nargs='?', type=int, help="Use Amqp (0) or Websocket (1) transport type. 
Default is Amqp.", default=0) + + +class _SendQueueTest(_SendTest): + def __init__(self, arguments) -> None: + super().__init__(arguments) + + self.queue_name=self.get_from_env('AZURE_SERVICEBUS_QUEUE_NAME') + self.sender=self.service_client.get_queue_sender(self.queue_name) + self.async_sender=self.async_service_client.get_queue_sender(self.queue_name) + + async def setup(self) -> None: + await super().setup() + self.sender.create_message_batch() + await self.async_sender.create_message_batch() + + async def close(self) -> None: + self.sender.close() + await self.async_sender.close() + await super().close() + + +class _SendTopicTest(_SendTest): + def __init__(self, arguments) -> None: + super().__init__(arguments) + + self.topic_name=self.get_from_env('AZURE_SERVICEBUS_TOPIC_NAME') + self.sender=self.service_client.get_topic_sender(self.topic_name) + self.async_sender=self.async_service_client.get_topic_sender(self.topic_name) + + async def setup(self) -> None: + await super().setup() + self.sender.create_message_batch() + await self.async_sender.create_message_batch() + + async def close(self) -> None: + self.sender.close() + await self.async_sender.close() + await super().close() \ No newline at end of file diff --git a/sdk/servicebus/azure-servicebus/tests/perf_tests/receive_message_batch.py b/sdk/servicebus/azure-servicebus/tests/perf_tests/receive_message_batch.py deleted file mode 100644 index e46eb331bb163..0000000000000 --- a/sdk/servicebus/azure-servicebus/tests/perf_tests/receive_message_batch.py +++ /dev/null @@ -1,31 +0,0 @@ -# -------------------------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for license information. 
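[Editor's note] The `_preload_queue` helper earlier in `_test_base.py` (and the deleted `send_message_batch.py` further down) rely on `ServiceBusMessageBatch.add_message` raising `ValueError` once a batch would exceed its size limit. A minimal self-contained sketch of that fill-flush-retry pattern, using a hypothetical `_StubBatch` in place of the real SDK batch class:

```python
# Sketch of the "fill until full, then flush" preload pattern.
# _StubBatch is a hypothetical stand-in for ServiceBusMessageBatch: like the
# real class, add_message raises ValueError once the batch is full.
class _StubBatch:
    def __init__(self, capacity):
        self.capacity = capacity
        self.messages = []

    def add_message(self, msg):
        if len(self.messages) >= self.capacity:
            raise ValueError("Batch full")
        self.messages.append(msg)


def preload(num_messages, batch_capacity):
    """Return the list of flushed batches covering num_messages messages."""
    sent = []
    batch = _StubBatch(batch_capacity)
    for i in range(num_messages):
        try:
            batch.add_message(i)
        except ValueError:
            # Batch full: flush it and start a new one holding this message.
            sent.append(batch)
            batch = _StubBatch(batch_capacity)
            batch.add_message(i)
    sent.append(batch)  # final, possibly partial, batch
    return sent
```

The real preload code follows the same shape, with `sender.send_messages(batch)` as the flush step.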
-# -------------------------------------------------------------------------------------------- - -import asyncio - -from ._test_base import _ReceiveTest - - -class ReceiveMessageBatchTest(_ReceiveTest): - def run_sync(self): - count = 0 - while count < self.args.num_messages: - batch = self.receiver.receive_messages( - max_message_count=self.args.num_messages - count, - max_wait_time=self.args.max_wait_time or None) - if self.args.peeklock: - for msg in batch: - self.receiver.complete_message(msg) - count += len(batch) - - async def run_async(self): - count = 0 - while count < self.args.num_messages: - batch = await self.async_receiver.receive_messages( - max_message_count=self.args.num_messages - count, - max_wait_time=self.args.max_wait_time or None) - if self.args.peeklock: - await asyncio.gather(*[self.async_receiver.complete_message(m) for m in batch]) - count += len(batch) diff --git a/sdk/servicebus/azure-servicebus/tests/perf_tests/receive_queue_message_batch.py b/sdk/servicebus/azure-servicebus/tests/perf_tests/receive_queue_message_batch.py new file mode 100644 index 0000000000000..4c444713af21b --- /dev/null +++ b/sdk/servicebus/azure-servicebus/tests/perf_tests/receive_queue_message_batch.py @@ -0,0 +1,29 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. 
+# -------------------------------------------------------------------------------------------- + +import asyncio + +from ._test_base import _QueueReceiveBatchTest + + +class ReceiveQueueMessageBatchTest(_QueueReceiveBatchTest): + def run_batch_sync(self) -> int: + batch = self.receiver.receive_messages( + max_message_count=self.args.num_messages, + max_wait_time=self.args.max_wait_time or None + ) + if self.args.peeklock: + for msg in batch: + self.receiver.complete_message(msg) + return len(batch) + + async def run_batch_async(self) -> int: + batch = await self.async_receiver.receive_messages( + max_message_count=self.args.num_messages, + max_wait_time=self.args.max_wait_time or None + ) + if self.args.peeklock: + await asyncio.gather(*[self.async_receiver.complete_message(m) for m in batch]) + return len(batch) \ No newline at end of file diff --git a/sdk/servicebus/azure-servicebus/tests/perf_tests/receive_message_stream.py b/sdk/servicebus/azure-servicebus/tests/perf_tests/receive_queue_message_stream.py similarity index 85% rename from sdk/servicebus/azure-servicebus/tests/perf_tests/receive_message_stream.py rename to sdk/servicebus/azure-servicebus/tests/perf_tests/receive_queue_message_stream.py index f9ef6473481b7..8c291a92e1366 100644 --- a/sdk/servicebus/azure-servicebus/tests/perf_tests/receive_message_stream.py +++ b/sdk/servicebus/azure-servicebus/tests/perf_tests/receive_queue_message_stream.py @@ -5,11 +5,11 @@ import asyncio -from ._test_base import _ReceiveTest +from ._test_base import _QueueReceiveTest -class ReceiveMessageStreamTest(_ReceiveTest): - def run_sync(self): +class ReceiveQueueMessageStreamTest(_QueueReceiveTest): + def run_sync(self) -> None: count = 0 if self.args.peeklock: for msg in self.receiver: @@ -23,7 +23,7 @@ def run_sync(self): break count += 1 - async def run_async(self): + async def run_async(self) -> None: count = 0 if self.args.peeklock: async for msg in self.async_receiver: diff --git 
a/sdk/servicebus/azure-servicebus/tests/perf_tests/receive_subscription_message_batch.py b/sdk/servicebus/azure-servicebus/tests/perf_tests/receive_subscription_message_batch.py new file mode 100644 index 0000000000000..58608d40aef02 --- /dev/null +++ b/sdk/servicebus/azure-servicebus/tests/perf_tests/receive_subscription_message_batch.py @@ -0,0 +1,27 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +import asyncio + +from ._test_base import _SubscriptionReceiveBatchTest + + +class ReceiveSubscriptionMessageBatchTest(_SubscriptionReceiveBatchTest): + def run_batch_sync(self) -> int: + batch = self.receiver.receive_messages( + max_message_count=self.args.num_messages, + max_wait_time=self.args.max_wait_time or None) + if self.args.peeklock: + for msg in batch: + self.receiver.complete_message(msg) + return len(batch) + + async def run_batch_async(self) -> int: + batch = await self.async_receiver.receive_messages( + max_message_count=self.args.num_messages, + max_wait_time=self.args.max_wait_time or None) + if self.args.peeklock: + await asyncio.gather(*[self.async_receiver.complete_message(m) for m in batch]) + return len(batch) \ No newline at end of file diff --git a/sdk/servicebus/azure-servicebus/tests/perf_tests/T1_legacy_tests/receive_message_stream.py b/sdk/servicebus/azure-servicebus/tests/perf_tests/receive_subscription_message_stream.py similarity index 76% rename from sdk/servicebus/azure-servicebus/tests/perf_tests/T1_legacy_tests/receive_message_stream.py rename to sdk/servicebus/azure-servicebus/tests/perf_tests/receive_subscription_message_stream.py index f740302e3c91a..7c3ef50759364 100644 --- 
a/sdk/servicebus/azure-servicebus/tests/perf_tests/T1_legacy_tests/receive_message_stream.py +++ b/sdk/servicebus/azure-servicebus/tests/perf_tests/receive_subscription_message_stream.py @@ -5,34 +5,34 @@ import asyncio -from ._test_base import _ReceiveTest +from ._test_base import _SubscriptionReceiveTest -class LegacyReceiveMessageStreamTest(_ReceiveTest): - def run_sync(self): +class ReceiveSubscriptionMessageStreamTest(_SubscriptionReceiveTest): + def run_sync(self) -> None: count = 0 if self.args.peeklock: for msg in self.receiver: if count >= self.args.num_messages: break count += 1 - msg.complete() + self.receiver.complete_message(msg) else: for msg in self.receiver: if count >= self.args.num_messages: break count += 1 - async def run_async(self): + async def run_async(self) -> None: count = 0 if self.args.peeklock: async for msg in self.async_receiver: if count >= self.args.num_messages: break count += 1 - await msg.complete() + await self.async_receiver.complete_message(msg) else: async for msg in self.async_receiver: if count >= self.args.num_messages: break - count += 1 + count += 1 \ No newline at end of file diff --git a/sdk/servicebus/azure-servicebus/tests/perf_tests/send_message.py b/sdk/servicebus/azure-servicebus/tests/perf_tests/send_message.py deleted file mode 100644 index 03887562c5a2e..0000000000000 --- a/sdk/servicebus/azure-servicebus/tests/perf_tests/send_message.py +++ /dev/null @@ -1,23 +0,0 @@ -# -------------------------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for license information. 
-# -------------------------------------------------------------------------------------------- - -from ._test_base import _SendTest - -from azure_devtools.perfstress_tests import get_random_bytes - -from azure.servicebus import ServiceBusMessage - -class SendMessageTest(_SendTest): - def __init__(self, arguments): - super().__init__(arguments) - self.data = get_random_bytes(self.args.message_size) - - def run_sync(self): - message = ServiceBusMessage(self.data) - self.sender.send_messages(message) - - async def run_async(self): - message = ServiceBusMessage(self.data) - await self.async_sender.send_messages(message) diff --git a/sdk/servicebus/azure-servicebus/tests/perf_tests/send_message_batch.py b/sdk/servicebus/azure-servicebus/tests/perf_tests/send_message_batch.py deleted file mode 100644 index 78bb0bf8f6695..0000000000000 --- a/sdk/servicebus/azure-servicebus/tests/perf_tests/send_message_batch.py +++ /dev/null @@ -1,40 +0,0 @@ -# -------------------------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for license information. 
-# -------------------------------------------------------------------------------------------- - -from ._test_base import _SendTest - -from azure_devtools.perfstress_tests import get_random_bytes - -from azure.servicebus import ServiceBusMessage - - -class SendMessageBatchTest(_SendTest): - def __init__(self, arguments): - super().__init__(arguments) - self.data = get_random_bytes(self.args.message_size) - - def run_sync(self): - batch = self.sender.create_message_batch() - for i in range(self.args.num_messages): - try: - batch.add_message(ServiceBusMessage(self.data)) - except ValueError: - # Batch full - self.sender.send_messages(batch) - batch = self.sender.create_message_batch() - batch.add_message(ServiceBusMessage(self.data)) - self.sender.send_messages(batch) - - async def run_async(self): - batch = await self.async_sender.create_message_batch() - for i in range(self.args.num_messages): - try: - batch.add_message(ServiceBusMessage(self.data)) - except ValueError: - # Batch full - await self.async_sender.send_messages(batch) - batch = await self.async_sender.create_message_batch() - batch.add_message(ServiceBusMessage(self.data)) - await self.async_sender.send_messages(batch) diff --git a/sdk/servicebus/azure-servicebus/tests/perf_tests/send_queue_message.py b/sdk/servicebus/azure-servicebus/tests/perf_tests/send_queue_message.py new file mode 100644 index 0000000000000..dca88c9ca6bdb --- /dev/null +++ b/sdk/servicebus/azure-servicebus/tests/perf_tests/send_queue_message.py @@ -0,0 +1,35 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. 
+# -------------------------------------------------------------------------------------------- + +from ._test_base import _SendQueueTest + +from azure_devtools.perfstress_tests import get_random_bytes + +from azure.servicebus import ServiceBusMessage + +class SendQueueMessageTest(_SendQueueTest): + def __init__(self, arguments) -> None: + super().__init__(arguments) + self.data = get_random_bytes(self.args.message_size) + + def run_batch_sync(self) -> int: + if self.args.batch_size > 1: + self.sender.send_messages( + [ServiceBusMessage(self.data) for _ in range(self.args.batch_size)] + ) + else: + self.sender.send_messages(ServiceBusMessage(self.data)) + + return self.args.batch_size + + async def run_batch_async(self) -> int: + if self.args.batch_size > 1: + await self.async_sender.send_messages( + [ServiceBusMessage(self.data) for _ in range(self.args.batch_size)] + ) + else: + await self.async_sender.send_messages(ServiceBusMessage(self.data)) + + return self.args.batch_size diff --git a/sdk/servicebus/azure-servicebus/tests/perf_tests/send_queue_message_batch.py b/sdk/servicebus/azure-servicebus/tests/perf_tests/send_queue_message_batch.py new file mode 100644 index 0000000000000..1603b5cf2f63b --- /dev/null +++ b/sdk/servicebus/azure-servicebus/tests/perf_tests/send_queue_message_batch.py @@ -0,0 +1,31 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. 
+# -------------------------------------------------------------------------------------------- + +from ._test_base import _SendQueueTest + +from azure_devtools.perfstress_tests import get_random_bytes + +from azure.servicebus import ServiceBusMessage + + +class SendQueueMessageBatchTest(_SendQueueTest): + def __init__(self, arguments) -> None: + super().__init__(arguments) + self.data = get_random_bytes(self.args.message_size) + + def run_batch_sync(self) -> int: + batch = self.sender.create_message_batch() + for _ in range(self.args.batch_size): + batch.add_message(ServiceBusMessage(self.data)) + self.sender.send_messages(batch) + return self.args.batch_size + + async def run_batch_async(self) -> int: + batch = await self.async_sender.create_message_batch() + for _ in range(self.args.batch_size): + batch.add_message(ServiceBusMessage(self.data)) + + await self.async_sender.send_messages(batch) + return self.args.batch_size \ No newline at end of file diff --git a/sdk/servicebus/azure-servicebus/tests/perf_tests/send_topic_message.py b/sdk/servicebus/azure-servicebus/tests/perf_tests/send_topic_message.py new file mode 100644 index 0000000000000..af8dd5e7dc7e9 --- /dev/null +++ b/sdk/servicebus/azure-servicebus/tests/perf_tests/send_topic_message.py @@ -0,0 +1,35 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. 
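[Editor's note] The `run_batch_sync`/`run_batch_async` methods in these tests follow a common batch-perf contract: each call performs one batched operation and returns how many messages it handled. Assuming the harness simply sums those return values over a timed window (an assumption; the perfstress framework itself is not shown in this diff), the accounting looks roughly like:

```python
import time


def measure_throughput(run_batch, duration_seconds):
    """Sum the per-call operation counts returned by run_batch over a window
    and convert to operations per second."""
    completed = 0
    deadline = time.monotonic() + duration_seconds
    while time.monotonic() < deadline:
        completed += run_batch()
    return completed / duration_seconds
```

This is why returning `self.args.batch_size` (or `len(batch)` on the receive side) matters: a wrong return value skews the reported throughput rather than failing the test.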
+# -------------------------------------------------------------------------------------------- + +from ._test_base import _SendTopicTest + +from azure_devtools.perfstress_tests import get_random_bytes + +from azure.servicebus import ServiceBusMessage + +class SendTopicMessageTest(_SendTopicTest): + def __init__(self, arguments) -> None: + super().__init__(arguments) + self.data = get_random_bytes(self.args.message_size) + + def run_batch_sync(self) -> int: + if self.args.batch_size > 1: + self.sender.send_messages( + [ServiceBusMessage(self.data) for _ in range(self.args.batch_size)] + ) + else: + self.sender.send_messages(ServiceBusMessage(self.data)) + + return self.args.batch_size + + async def run_batch_async(self) -> int: + if self.args.batch_size > 1: + await self.async_sender.send_messages( + [ServiceBusMessage(self.data) for _ in range(self.args.batch_size)] + ) + else: + await self.async_sender.send_messages(ServiceBusMessage(self.data)) + + return self.args.batch_size diff --git a/sdk/servicebus/azure-servicebus/tests/perf_tests/send_topic_message_batch.py b/sdk/servicebus/azure-servicebus/tests/perf_tests/send_topic_message_batch.py new file mode 100644 index 0000000000000..08267f78bdec4 --- /dev/null +++ b/sdk/servicebus/azure-servicebus/tests/perf_tests/send_topic_message_batch.py @@ -0,0 +1,30 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. 
+# -------------------------------------------------------------------------------------------- + +from ._test_base import _SendTopicTest + +from azure_devtools.perfstress_tests import get_random_bytes + +from azure.servicebus import ServiceBusMessage + + +class SendTopicMessageBatchTest(_SendTopicTest): + def __init__(self, arguments) -> None: + super().__init__(arguments) + self.data = get_random_bytes(self.args.message_size) + + def run_batch_sync(self) -> int: + batch = self.sender.create_message_batch() + for _ in range(self.args.batch_size): + batch.add_message(ServiceBusMessage(self.data)) + self.sender.send_messages(batch) + return self.args.batch_size + + async def run_batch_async(self) -> int: + batch = await self.async_sender.create_message_batch() + for _ in range(self.args.batch_size): + batch.add_message(ServiceBusMessage(self.data)) + await self.async_sender.send_messages(batch) + return self.args.batch_size \ No newline at end of file diff --git a/sdk/servicebus/azure-servicebus/tests/servicebus_preparer.py b/sdk/servicebus/azure-servicebus/tests/servicebus_preparer.py index 8c832411b0c01..2a34569e14dfb 100644 --- a/sdk/servicebus/azure-servicebus/tests/servicebus_preparer.py +++ b/sdk/servicebus/azure-servicebus/tests/servicebus_preparer.py @@ -16,7 +16,6 @@ from devtools_testutils import ( AzureMgmtPreparer, FakeResource, get_region_override, add_general_regex_sanitizer ) - from devtools_testutils.resource_testcase import RESOURCE_GROUP_PARAM SERVICEBUS_DEFAULT_AUTH_RULE_NAME = 'RootManageSharedAccessKey' @@ -183,11 +182,6 @@ def create_resource(self, name, **kwargs): self.connection_string = key.primary_connection_string self.key_name = key.key_name self.primary_key = key.primary_key - - self.test_class_instance.scrubber.register_name_pair( - name, - self.resource_moniker - ) else: self.resource = FakeResource(name=name, id=name) self.connection_string = 
f"Endpoint=sb://{name}{SERVICEBUS_ENDPOINT_SUFFIX}/;SharedAccessKeyName=test;SharedAccessKey=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=" @@ -195,9 +189,9 @@ def create_resource(self, name, **kwargs): self.primary_key = 'ZmFrZV9hY29jdW50X2tleQ==' return { self.parameter_name: self.resource, - '{}_connection_string'.format(self.parameter_name): self.connection_string, - '{}_key_name'.format(self.parameter_name): self.key_name, - '{}_primary_key'.format(self.parameter_name): self.primary_key, + f'{self.parameter_name}_connection_string': self.connection_string, + f'{self.parameter_name}_key_name': self.key_name, + f'{self.parameter_name}_primary_key': self.primary_key, } def remove_resource(self, name, **kwargs): @@ -290,10 +284,6 @@ def create_resource(self, name, **kwargs): raise time.sleep(3) - self.test_class_instance.scrubber.register_name_pair( - name, - self.resource_moniker - ) else: self.resource = FakeResource(name=name, id=name) return { @@ -360,10 +350,6 @@ def create_resource(self, name, **kwargs): raise time.sleep(3) - self.test_class_instance.scrubber.register_name_pair( - name, - self.resource_moniker - ) else: self.resource = FakeResource(name=name, id=name) return { @@ -443,10 +429,6 @@ def create_resource(self, name, **kwargs): raise time.sleep(3) - self.test_class_instance.scrubber.register_name_pair( - name, - self.resource_moniker - ) else: self.resource = FakeResource(name=name, id=name) return { @@ -508,10 +490,6 @@ def create_resource(self, name, **kwargs): key = self.client.namespaces.list_keys(group.name, namespace.name, name) connection_string = key.primary_connection_string - self.test_class_instance.scrubber.register_name_pair( - name, - self.resource_moniker - ) else: self.resource = FakeResource(name=name, id=name) connection_string = 'https://microsoft.com' @@ -579,10 +557,6 @@ def create_resource(self, name, **kwargs): key = self.client.queues.list_keys(group.name, namespace.name, queue.name, name) connection_string = 
key.primary_connection_string - self.test_class_instance.scrubber.register_name_pair( - name, - self.resource_moniker - ) else: self.resource = FakeResource(name=name, id=name) connection_string = 'https://microsoft.com' diff --git a/sdk/servicebus/azure-servicebus/tests/test_connection_string_parser.py b/sdk/servicebus/azure-servicebus/tests/test_connection_string_parser.py index 6877c1171ce4a..00e3e46a65f9d 100644 --- a/sdk/servicebus/azure-servicebus/tests/test_connection_string_parser.py +++ b/sdk/servicebus/azure-servicebus/tests/test_connection_string_parser.py @@ -11,9 +11,9 @@ parse_connection_string, ) -from devtools_testutils import AzureMgmtTestCase +from devtools_testutils import AzureMgmtRecordedTestCase -class ServiceBusConnectionStringParserTests(AzureMgmtTestCase): +class ServiceBusConnectionStringParserTests(AzureMgmtRecordedTestCase): def test_sb_conn_str_parse_cs(self, **kwargs): conn_str = 'Endpoint=sb://resourcename.servicebus.windows.net/;SharedAccessKeyName=test;SharedAccessKey=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=' parse_result = parse_connection_string(conn_str) diff --git a/sdk/servicebus/azure-servicebus/tests/test_message.py b/sdk/servicebus/azure-servicebus/tests/test_message.py index 2b20c443555b0..f1b16e39bc468 100644 --- a/sdk/servicebus/azure-servicebus/tests/test_message.py +++ b/sdk/servicebus/azure-servicebus/tests/test_message.py @@ -1,6 +1,19 @@ -import uamqp +import os +import pytest +try: + import uamqp + from azure.servicebus._transport._uamqp_transport import UamqpTransport +except (ModuleNotFoundError, ImportError): + uamqp = None from datetime import datetime, timedelta -from azure.servicebus import ServiceBusMessage, ServiceBusReceivedMessage, ServiceBusMessageState +from azure.servicebus import ( + ServiceBusClient, + ServiceBusMessage, + ServiceBusReceivedMessage, + ServiceBusMessageState, + ServiceBusReceiveMode, + ServiceBusMessageBatch +) from azure.servicebus._common.constants import ( _X_OPT_PARTITION_KEY, 
_X_OPT_VIA_PARTITION_KEY, @@ -12,7 +25,15 @@ AmqpMessageProperties, AmqpMessageHeader ) +from azure.servicebus._pyamqp.message import Message +from azure.servicebus._pyamqp._message_backcompat import LegacyBatchMessage +from azure.servicebus._transport._pyamqp_transport import PyamqpTransport + +from devtools_testutils import AzureMgmtRecordedTestCase, CachedResourceGroupPreparer +from servicebus_preparer import CachedServiceBusNamespacePreparer, ServiceBusQueuePreparer +from utilities import uamqp_transport as get_uamqp_transport, ArgPasser +uamqp_transport_params, uamqp_transport_ids = get_uamqp_transport() def test_servicebus_message_repr(): message = ServiceBusMessage("hello") @@ -39,97 +60,197 @@ def test_servicebus_message_repr_with_props(): assert "application_properties={'prop': 'test'}, session_id=id_session," in message.__repr__() assert "content_type=content type, correlation_id=correlation, to=forward to, reply_to=reply to, reply_to_session_id=reply to session, subject=github, time_to_live=0:00:30, partition_key=id_session, scheduled_enqueue_time_utc" in message.__repr__() - -def test_servicebus_received_message_repr(): - uamqp_received_message = uamqp.message.Message( - body=b'data', - annotations={ +@pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) +def test_servicebus_received_message_repr(uamqp_transport): + my_frame = [0,0,0] + if uamqp_transport: + received_message = uamqp.message.Message( + body=[b'data'], + annotations={ + _X_OPT_PARTITION_KEY: b'r_key', + _X_OPT_VIA_PARTITION_KEY: b'r_via_key', + _X_OPT_SCHEDULED_ENQUEUE_TIME: 123424566, + }, + properties={} + ) + else: + received_message = Message( + data=[b'data'], + message_annotations={ _X_OPT_PARTITION_KEY: b'r_key', _X_OPT_VIA_PARTITION_KEY: b'r_via_key', _X_OPT_SCHEDULED_ENQUEUE_TIME: 123424566, }, - properties=uamqp.message.MessageProperties() + properties={} ) - received_message = ServiceBusReceivedMessage(uamqp_received_message, 
receiver=None) + received_message = ServiceBusReceivedMessage(received_message, receiver=None, frame=my_frame) repr_str = received_message.__repr__() assert "application_properties=None, session_id=None" in repr_str - assert "content_type=None, correlation_id=None, to=None, reply_to=None, reply_to_session_id=None, subject=None," + assert "content_type=None, correlation_id=None, to=None, reply_to=None, reply_to_session_id=None, subject=None," in repr_str assert "partition_key=r_key, scheduled_enqueue_time_utc" in repr_str -def test_servicebus_received_state(): - uamqp_received_message = uamqp.message.Message( - body=b'data', - annotations={ - b"x-opt-message-state": 3 - }, - properties=uamqp.message.MessageProperties() - ) - received_message = ServiceBusReceivedMessage(uamqp_received_message, receiver=None) +@pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) +def test_servicebus_received_state(uamqp_transport): + my_frame = [0,0,0] + if uamqp_transport: + amqp_received_message = uamqp.message.Message( + body=[b'data'], + annotations={ + b"x-opt-message-state": 3 + }, + ) + else: + amqp_received_message = Message( + data=[b'data'], + message_annotations={ + b"x-opt-message-state": 3 + }, + ) + received_message = ServiceBusReceivedMessage(amqp_received_message, receiver=None, frame=my_frame) assert received_message.state == 3 - uamqp_received_message = uamqp.message.Message( - body=b'data', - annotations={ - b"x-opt-message-state": 1 - }, - properties=uamqp.message.MessageProperties() - ) - received_message = ServiceBusReceivedMessage(uamqp_received_message, receiver=None) + if uamqp_transport: + amqp_received_message = uamqp.message.Message( + body=[b'data'], + annotations={ + b"x-opt-message-state": 1 + }, + properties={} + ) + else: + amqp_received_message = Message( + data=[b'data'], + message_annotations={ + b"x-opt-message-state": 1 + }, + properties={} + ) + received_message = 
ServiceBusReceivedMessage(amqp_received_message, receiver=None) assert received_message.state == ServiceBusMessageState.DEFERRED - uamqp_received_message = uamqp.message.Message( - body=b'data', - annotations={ - }, - properties=uamqp.message.MessageProperties() - ) - received_message = ServiceBusReceivedMessage(uamqp_received_message, receiver=None) + if uamqp_transport: + amqp_received_message = uamqp.message.Message( + body=[b'data'], + annotations={ + }, + properties={} + ) + else: + amqp_received_message = Message( + data=[b'data'], + message_annotations={ + }, + properties={} + ) + received_message = ServiceBusReceivedMessage(amqp_received_message, receiver=None) assert received_message.state == ServiceBusMessageState.ACTIVE - uamqp_received_message = uamqp.message.Message( - body=b'data', - properties=uamqp.message.MessageProperties() - ) - received_message = ServiceBusReceivedMessage(uamqp_received_message, receiver=None) + if uamqp_transport: + amqp_received_message = uamqp.message.Message( + body=[b'data'], + properties={} + ) + else: + amqp_received_message = Message( + data=[b'data'], + properties={} + ) + received_message = ServiceBusReceivedMessage(amqp_received_message, receiver=None) assert received_message.state == ServiceBusMessageState.ACTIVE - uamqp_received_message = uamqp.message.Message( - body=b'data', - annotations={ - b"x-opt-message-state": 0 - }, - properties=uamqp.message.MessageProperties() - ) - received_message = ServiceBusReceivedMessage(uamqp_received_message, receiver=None) + if uamqp_transport: + amqp_received_message = uamqp.message.Message( + body=[b'data'], + annotations={ + b"x-opt-message-state": 0 + }, + properties={} + ) + else: + amqp_received_message = Message( + data=[b'data'], + message_annotations={ + b"x-opt-message-state": 0 + }, + properties={} + ) + received_message = ServiceBusReceivedMessage(amqp_received_message, receiver=None) assert received_message.state == ServiceBusMessageState.ACTIVE -def 
test_servicebus_received_message_repr_with_props(): - uamqp_received_message = uamqp.message.Message( - body=b'data', - annotations={ - _X_OPT_PARTITION_KEY: b'r_key', - _X_OPT_VIA_PARTITION_KEY: b'r_via_key', - _X_OPT_SCHEDULED_ENQUEUE_TIME: 123424566, - }, - properties=uamqp.message.MessageProperties( - message_id="id_message", - absolute_expiry_time=100, - content_type="content type", - correlation_id="correlation", - subject="github", - group_id="id_session", - reply_to="reply to", - reply_to_group_id="reply to group" - ) +@pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) +def test_servicebus_received_message_repr_with_props(uamqp_transport): + my_frame = [0,0,0] + properties = AmqpMessageProperties( + message_id="id_message", + absolute_expiry_time=100, + content_type="content type", + correlation_id="correlation", + subject="github", + group_id="id_session", + reply_to="reply to", + reply_to_group_id="reply to group" ) + message_annotations = { + _X_OPT_PARTITION_KEY: b'r_key', + _X_OPT_VIA_PARTITION_KEY: b'r_via_key', + _X_OPT_SCHEDULED_ENQUEUE_TIME: 123424566, + } + data = [b'data'] + if uamqp_transport: + amqp_received_message = uamqp.message.Message( + body=data, + annotations= message_annotations, + properties=properties + ) + else: + amqp_received_message = Message( + data=data, + message_annotations= message_annotations, + properties=properties + ) received_message = ServiceBusReceivedMessage( - message=uamqp_received_message, + message=amqp_received_message, receiver=None, + frame=my_frame ) assert "application_properties=None, session_id=id_session" in received_message.__repr__() assert "content_type=content type, correlation_id=correlation, to=None, reply_to=reply to, reply_to_session_id=reply to group, subject=github" in received_message.__repr__() assert "partition_key=r_key, scheduled_enqueue_time_utc" in received_message.__repr__() +@pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, 
ids=uamqp_transport_ids) +def test_servicebus_message_batch(uamqp_transport): + if uamqp_transport: + amqp_transport=UamqpTransport + else: + amqp_transport=PyamqpTransport + + batch = ServiceBusMessageBatch(max_size_in_bytes=240, partition_key="par", amqp_transport=amqp_transport) + batch.add_message( + ServiceBusMessage( + "A", + application_properties={b"val1": b"a", "val2": "b"}, + session_id="session_id", + message_id="message_id", + scheduled_enqueue_time_utc=datetime.now(), + time_to_live=timedelta(seconds=60), + content_type="content_type", + correlation_id="cid", + subject="sub", + partition_key="session_id", + to="to", + reply_to="reply_to", + reply_to_session_id="reply_to_session_id" + ) + ) + assert str(batch) == "ServiceBusMessageBatch(max_size_in_bytes=240, message_count=1)" + assert repr(batch) == "ServiceBusMessageBatch(max_size_in_bytes=240, message_count=1)" + + assert batch.size_in_bytes == 238 and len(batch) == 1 + + with pytest.raises(ValueError): + batch.add_message(ServiceBusMessage("A")) + + assert batch.message def test_amqp_message(): sb_message = ServiceBusMessage(body=None) @@ -242,3 +363,474 @@ def test_servicebus_message_time_to_live(): assert message.time_to_live == timedelta(seconds=30) message.time_to_live = timedelta(days=1) assert message.time_to_live == timedelta(days=1) + + + +class TestServiceBusMessageBackcompat(AzureMgmtRecordedTestCase): + + @pytest.mark.liveTest + @pytest.mark.live_test_only + @CachedResourceGroupPreparer(name_prefix='servicebustest') + @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') + @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_message_backcompat_receive_and_delete_databody(self, uamqp_transport, *, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + queue_name = servicebus_queue.name + 
outgoing_message = ServiceBusMessage( + body="hello", + application_properties={'prop': 'test'}, + session_id="id_session", + message_id="id_message", + time_to_live=timedelta(seconds=30), + content_type="content type", + correlation_id="correlation", + subject="github", + partition_key="id_session", + to="forward to", + reply_to="reply to", + reply_to_session_id="reply to session" + ) + + sb_client = ServiceBusClient.from_connection_string( + servicebus_namespace_connection_string, logging_enable=True, uamqp_transport=uamqp_transport) + with sb_client.get_queue_sender(queue_name) as sender: + sender.send_messages(outgoing_message) + + # outgoing_message.message will be LegacyMessage for both uamqp and pyamqp, same as in EH. + # Previously, an "empty"/useless uamqp.Message was returned b/c the outgoing message is constructed + # in send. So, returning LegacyMessage now should not cause issues. + assert outgoing_message.message + with pytest.raises(TypeError): + outgoing_message.message.accept() + with pytest.raises(TypeError): + outgoing_message.message.release() + with pytest.raises(TypeError): + outgoing_message.message.reject() + with pytest.raises(TypeError): + outgoing_message.message.modify(True, True) + try: + assert outgoing_message.message.state == uamqp.constants.MessageState.SendComplete + except AttributeError: # uamqp not installed + pass + assert outgoing_message.message.settled + assert outgoing_message.message.delivery_annotations is None + assert outgoing_message.message.delivery_no is None + assert outgoing_message.message.delivery_tag is None + assert outgoing_message.message.on_send_complete is None + assert outgoing_message.message.footer is None + assert outgoing_message.message.retries >= 0 + assert outgoing_message.message.idle_time >= 0 + with pytest.raises(Exception): + outgoing_message.message.gather() + assert isinstance(outgoing_message.message.encode_message(), bytes) + assert outgoing_message.message.get_message_encoded_size() == 208 + 
assert list(outgoing_message.message.get_data()) == [b'hello'] + assert outgoing_message.message.application_properties == {'prop': 'test'} + assert outgoing_message.message.get_message() # C instance. + assert len(outgoing_message.message.annotations) == 1 + assert list(outgoing_message.message.annotations.values())[0] == 'id_session' + assert str(outgoing_message.message.header) == str({'delivery_count': 0, 'time_to_live': 30000, 'first_acquirer': None, 'durable': None, 'priority': None}) + assert outgoing_message.message.header.get_header_obj().delivery_count == 0 + assert outgoing_message.message.properties.message_id == b'id_message' + assert outgoing_message.message.properties.user_id is None + assert outgoing_message.message.properties.to == b'forward to' + assert outgoing_message.message.properties.subject == b'github' + assert outgoing_message.message.properties.reply_to == b'reply to' + assert outgoing_message.message.properties.correlation_id == b'correlation' + assert outgoing_message.message.properties.content_type == b'content type' + assert outgoing_message.message.properties.content_encoding is None + assert outgoing_message.message.properties.absolute_expiry_time + assert outgoing_message.message.properties.creation_time + assert outgoing_message.message.properties.group_id == b'id_session' + assert outgoing_message.message.properties.group_sequence is None + assert outgoing_message.message.properties.reply_to_group_id == b'reply to session' + assert outgoing_message.message.properties.get_properties_obj().message_id + + with sb_client.get_queue_receiver(queue_name, + receive_mode=ServiceBusReceiveMode.RECEIVE_AND_DELETE, + max_wait_time=10) as receiver: + batch = receiver.receive_messages() + incoming_message = batch[0] + assert incoming_message.message + try: + assert incoming_message.message.state == uamqp.constants.MessageState.ReceivedSettled + except AttributeError: # uamqp not installed + pass + assert incoming_message.message.settled + 
assert incoming_message.message.delivery_annotations == {} + assert incoming_message.message.delivery_no >= 1 + assert incoming_message.message.delivery_tag is None + assert incoming_message.message.on_send_complete is None + assert incoming_message.message.footer is None + assert incoming_message.message.retries >= 0 + assert incoming_message.message.idle_time == 0 + with pytest.raises(Exception): + incoming_message.message.gather() + assert isinstance(incoming_message.message.encode_message(), bytes) + # TODO: uamqp size = pyamqp size + 4? + # uamqp bug accounts for 3 bytes: + # - durable/first_acquirer/priority set by default in uamqp, None in pyamqp + # - setting pyamqp values for durable/first_acquirer increases pyamqp size = 269 + if uamqp_transport: + encoded_size = 267 + else: + encoded_size = 263 + assert incoming_message.message.get_message_encoded_size() == encoded_size + assert list(incoming_message.message.get_data()) == [b'hello'] + assert incoming_message.message.application_properties == {b'prop': b'test'} + assert incoming_message.message.get_message() # C instance. + assert len(incoming_message.message.annotations) == 3 + assert incoming_message.message.annotations[b'x-opt-enqueued-time'] > 0 + assert incoming_message.message.annotations[b'x-opt-sequence-number'] > 0 + assert incoming_message.message.annotations[b'x-opt-partition-key'] == b'id_session' + if uamqp_transport: + # uamqp bugs: + # 1) in uamqp.get_header_obj(): + # delivery_count should be 0, but b/c header obj is not mutable, value is not replaced. + # 2) MessageHeader.durable/first_acquirer are being set to True always on received message + # These properties should not be modified by uamqp. 
+ assert ", 'time_to_live': 30000" in str(incoming_message.message.header) + else: + assert incoming_message.message.header.get_header_obj().delivery_count == 0 + assert ", 'time_to_live': 30000, 'first_acquirer': None, 'durable': None, 'priority': None}" in str(incoming_message.message.header) + assert incoming_message.message.properties.message_id == b'id_message' + assert incoming_message.message.properties.user_id is None + assert incoming_message.message.properties.to == b'forward to' + assert incoming_message.message.properties.subject == b'github' + assert incoming_message.message.properties.reply_to == b'reply to' + assert incoming_message.message.properties.correlation_id == b'correlation' + assert incoming_message.message.properties.content_type == b'content type' + assert incoming_message.message.properties.content_encoding is None + assert incoming_message.message.properties.absolute_expiry_time + assert incoming_message.message.properties.creation_time + assert incoming_message.message.properties.group_id == b'id_session' + assert incoming_message.message.properties.group_sequence is None + assert incoming_message.message.properties.reply_to_group_id == b'reply to session' + assert incoming_message.message.properties.get_properties_obj().message_id + assert not incoming_message.message.accept() + assert not incoming_message.message.release() + assert not incoming_message.message.reject() + assert not incoming_message.message.modify(True, True) + + @pytest.mark.liveTest + @pytest.mark.live_test_only + @CachedResourceGroupPreparer(name_prefix='servicebustest') + @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') + @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_message_backcompat_peek_lock_databody(self, uamqp_transport, *, servicebus_namespace_connection_string=None, 
servicebus_queue=None, **kwargs): + queue_name = servicebus_queue.name + outgoing_message = ServiceBusMessage( + body="hello", + application_properties={'prop': 'test'}, + session_id="id_session", + message_id="id_message", + time_to_live=timedelta(seconds=30), + content_type="content type", + correlation_id="correlation", + subject="github", + partition_key="id_session", + to="forward to", + reply_to="reply to", + reply_to_session_id="reply to session" + ) + + sb_client = ServiceBusClient.from_connection_string( + servicebus_namespace_connection_string, logging_enable=True, uamqp_transport=uamqp_transport) + with sb_client.get_queue_sender(queue_name) as sender: + sender.send_messages(outgoing_message) + + assert outgoing_message.message + with pytest.raises(TypeError): + outgoing_message.message.accept() + with pytest.raises(TypeError): + outgoing_message.message.release() + with pytest.raises(TypeError): + outgoing_message.message.reject() + with pytest.raises(TypeError): + outgoing_message.message.modify(True, True) + try: + assert outgoing_message.message.state == uamqp.constants.MessageState.SendComplete + except AttributeError: # uamqp not installed + pass + assert outgoing_message.message.settled + assert outgoing_message.message.delivery_annotations is None + assert outgoing_message.message.delivery_no is None + assert outgoing_message.message.delivery_tag is None + assert outgoing_message.message.on_send_complete is None + assert outgoing_message.message.footer is None + assert outgoing_message.message.retries >= 0 + assert outgoing_message.message.idle_time >= 0 + with pytest.raises(Exception): + outgoing_message.message.gather() + assert isinstance(outgoing_message.message.encode_message(), bytes) + assert outgoing_message.message.get_message_encoded_size() == 208 + assert list(outgoing_message.message.get_data()) == [b'hello'] + assert outgoing_message.message.application_properties == {'prop': 'test'} + assert outgoing_message.message.get_message() # 
C instance. + assert len(outgoing_message.message.annotations) == 1 + assert list(outgoing_message.message.annotations.values())[0] == 'id_session' + assert str(outgoing_message.message.header) == str({'delivery_count': 0, 'time_to_live': 30000, 'first_acquirer': None, 'durable': None, 'priority': None}) + assert outgoing_message.message.header.get_header_obj().delivery_count == 0 + assert outgoing_message.message.properties.message_id == b'id_message' + assert outgoing_message.message.properties.user_id is None + assert outgoing_message.message.properties.to == b'forward to' + assert outgoing_message.message.properties.subject == b'github' + assert outgoing_message.message.properties.reply_to == b'reply to' + assert outgoing_message.message.properties.correlation_id == b'correlation' + assert outgoing_message.message.properties.content_type == b'content type' + assert outgoing_message.message.properties.content_encoding is None + assert outgoing_message.message.properties.absolute_expiry_time + assert outgoing_message.message.properties.creation_time + assert outgoing_message.message.properties.group_id == b'id_session' + assert outgoing_message.message.properties.group_sequence is None + assert outgoing_message.message.properties.reply_to_group_id == b'reply to session' + assert outgoing_message.message.properties.get_properties_obj().message_id + + with sb_client.get_queue_receiver(queue_name, + receive_mode=ServiceBusReceiveMode.PEEK_LOCK, + max_wait_time=10) as receiver: + batch = receiver.receive_messages() + incoming_message = batch[0] + assert incoming_message.message + try: + assert incoming_message.message.state == uamqp.constants.MessageState.ReceivedUnsettled + except AttributeError: # uamqp not installed + pass + assert not incoming_message.message.settled + assert incoming_message.message.delivery_annotations[b'x-opt-lock-token'] + assert incoming_message.message.delivery_no >= 1 + assert incoming_message.message.delivery_tag + assert 
incoming_message.message.on_send_complete is None + assert incoming_message.message.footer is None + assert incoming_message.message.retries >= 0 + assert incoming_message.message.idle_time == 0 + with pytest.raises(Exception): + incoming_message.message.gather() + assert isinstance(incoming_message.message.encode_message(), bytes) + + if uamqp_transport: + encoded_size = 334 + else: + # uamqp bug: sets durable/first_acquirer/priority by default on incoming message + # pyamqp = 339 if durable/first_acquirer set + encoded_size = 333 + assert incoming_message.message.get_message_encoded_size() == encoded_size + assert list(incoming_message.message.get_data()) == [b'hello'] + assert incoming_message.message.application_properties == {b'prop': b'test'} + assert incoming_message.message.get_message() # C instance. + assert len(incoming_message.message.annotations) == 4 + assert incoming_message.message.annotations[b'x-opt-enqueued-time'] > 0 + assert incoming_message.message.annotations[b'x-opt-sequence-number'] > 0 + assert incoming_message.message.annotations[b'x-opt-partition-key'] == b'id_session' + assert incoming_message.message.annotations[b'x-opt-locked-until'] + if uamqp_transport: + # uamqp bugs: + # 1) in uamqp.get_header_obj(): + # delivery_count should be 0, but b/c header obj is not mutable, value is not replaced. + # 2) MessageHeader.durable/first_acquirer are always being set to True on received message + # These properties should not be modified by uamqp. By default, should be None when not set. 
+ assert ", 'time_to_live': 30000" in str(incoming_message.message.header) + else: + assert incoming_message.message.header.get_header_obj().delivery_count == 0 + assert str(incoming_message.message.header) == str({'delivery_count': 0, 'time_to_live': 30000, 'first_acquirer': None, 'durable': None, 'priority': None}) + assert incoming_message.message.properties.message_id == b'id_message' + assert incoming_message.message.properties.user_id is None + assert incoming_message.message.properties.to == b'forward to' + assert incoming_message.message.properties.subject == b'github' + assert incoming_message.message.properties.reply_to == b'reply to' + assert incoming_message.message.properties.correlation_id == b'correlation' + assert incoming_message.message.properties.content_type == b'content type' + assert incoming_message.message.properties.content_encoding is None + assert incoming_message.message.properties.absolute_expiry_time + assert incoming_message.message.properties.creation_time + assert incoming_message.message.properties.group_id == b'id_session' + assert incoming_message.message.properties.group_sequence is None + assert incoming_message.message.properties.reply_to_group_id == b'reply to session' + assert incoming_message.message.properties.get_properties_obj().message_id + assert incoming_message.message.accept() + # TODO: State isn't updated if settled correctly via the receiver. 
+ try: + assert incoming_message.message.state == uamqp.constants.MessageState.ReceivedSettled + except AttributeError: # uamqp not installed + pass + assert incoming_message.message.settled + assert not incoming_message.message.release() + assert not incoming_message.message.reject() + assert not incoming_message.message.modify(True, True) + + @pytest.mark.liveTest + @pytest.mark.live_test_only + @CachedResourceGroupPreparer(name_prefix='servicebustest') + @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') + @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_message_backcompat_receive_and_delete_valuebody(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): + queue_name = servicebus_queue.name + outgoing_message = AmqpAnnotatedMessage(value_body={b"key": b"value"}) + + with pytest.raises(AttributeError): + outgoing_message.message + + sb_client = ServiceBusClient.from_connection_string( + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) + with sb_client.get_queue_sender(queue_name) as sender: + sender.send_messages(outgoing_message) + + with pytest.raises(AttributeError): + outgoing_message.message + + with sb_client.get_queue_receiver(queue_name, + receive_mode=ServiceBusReceiveMode.RECEIVE_AND_DELETE, + max_wait_time=10) as receiver: + batch = receiver.receive_messages() + incoming_message = batch[0] + assert incoming_message.message + try: + assert incoming_message.message.state == uamqp.constants.MessageState.ReceivedSettled + except AttributeError: # uamqp not installed + pass + assert incoming_message.message.settled + with pytest.raises(Exception): + incoming_message.message.gather() + assert incoming_message.message.get_data() == {b"key": b"value"} + assert not 
incoming_message.message.accept() + assert not incoming_message.message.release() + assert not incoming_message.message.reject() + assert not incoming_message.message.modify(True, True) + + @pytest.mark.liveTest + @pytest.mark.live_test_only + @CachedResourceGroupPreparer(name_prefix='servicebustest') + @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') + @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_message_backcompat_peek_lock_valuebody(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): + queue_name = servicebus_queue.name + outgoing_message = AmqpAnnotatedMessage(value_body={b"key": b"value"}) + + with pytest.raises(AttributeError): + outgoing_message.message + + sb_client = ServiceBusClient.from_connection_string( + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) + with sb_client.get_queue_sender(queue_name) as sender: + sender.send_messages(outgoing_message) + + with pytest.raises(AttributeError): + outgoing_message.message + + with sb_client.get_queue_receiver(queue_name, + receive_mode=ServiceBusReceiveMode.PEEK_LOCK, + max_wait_time=10) as receiver: + batch = receiver.receive_messages() + incoming_message = batch[0] + assert incoming_message.message + # assert incoming_message.message.state == uamqp.constants.MessageState.ReceivedUnsettled + assert not incoming_message.message.settled + assert incoming_message.message.delivery_annotations[b'x-opt-lock-token'] + assert incoming_message.message.delivery_no >= 1 + assert incoming_message.message.delivery_tag + with pytest.raises(Exception): + incoming_message.message.gather() + assert incoming_message.message.get_data() == {b"key": b"value"} + assert incoming_message.message.accept() + # assert incoming_message.message.state == 
uamqp.constants.MessageState.ReceivedSettled + assert incoming_message.message.settled + assert not incoming_message.message.release() + assert not incoming_message.message.reject() + assert not incoming_message.message.modify(True, True) + + @pytest.mark.liveTest + @pytest.mark.live_test_only + @CachedResourceGroupPreparer(name_prefix='servicebustest') + @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') + @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_message_backcompat_receive_and_delete_sequencebody(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): + queue_name = servicebus_queue.name + outgoing_message = AmqpAnnotatedMessage(sequence_body=[1, 2, 3]) + + with pytest.raises(AttributeError): + outgoing_message.message + + sb_client = ServiceBusClient.from_connection_string( + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) + with sb_client.get_queue_sender(queue_name) as sender: + sender.send_messages(outgoing_message) + + with pytest.raises(AttributeError): + outgoing_message.message + + with sb_client.get_queue_receiver(queue_name, + receive_mode=ServiceBusReceiveMode.RECEIVE_AND_DELETE, + max_wait_time=10) as receiver: + batch = receiver.receive_messages() + incoming_message = batch[0] + assert incoming_message.message + try: + assert incoming_message.message.state == uamqp.constants.MessageState.ReceivedSettled + except AttributeError: # uamqp not installed + pass + assert incoming_message.message.settled + with pytest.raises(Exception): + incoming_message.message.gather() + assert list(incoming_message.message.get_data()) == [[1, 2, 3]] + assert not incoming_message.message.accept() + assert not incoming_message.message.release() + assert not incoming_message.message.reject() 
+ assert not incoming_message.message.modify(True, True) + + @pytest.mark.liveTest + @pytest.mark.live_test_only + @CachedResourceGroupPreparer(name_prefix='servicebustest') + @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') + @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_message_batch_backcompat(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): + queue_name = servicebus_queue.name + outgoing_message = AmqpAnnotatedMessage(sequence_body=[1, 2, 3]) + + with pytest.raises(AttributeError): + outgoing_message.message + + sb_client = ServiceBusClient.from_connection_string( + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) + with sb_client.get_queue_sender(queue_name) as sender: + sender.send_messages(outgoing_message) + + with pytest.raises(AttributeError): + outgoing_message.message + + with sb_client.get_queue_receiver(queue_name, + receive_mode=ServiceBusReceiveMode.PEEK_LOCK, + max_wait_time=10) as receiver: + batch = receiver.receive_messages() + incoming_message = batch[0] + assert incoming_message.message + try: + assert incoming_message.message.state == uamqp.constants.MessageState.ReceivedUnsettled + except AttributeError: # uamqp not installed + pass + assert not incoming_message.message.settled + assert incoming_message.message.delivery_annotations[b'x-opt-lock-token'] + assert incoming_message.message.delivery_no >= 1 + assert incoming_message.message.delivery_tag + with pytest.raises(Exception): + incoming_message.message.gather() + assert list(incoming_message.message.get_data()) == [[1, 2, 3]] + assert incoming_message.message.accept() + try: + assert incoming_message.message.state == uamqp.constants.MessageState.ReceivedSettled + except AttributeError: # uamqp not 
installed + pass + assert incoming_message.message.settled + assert not incoming_message.message.release() + assert not incoming_message.message.reject() + assert not incoming_message.message.modify(True, True) diff --git a/sdk/servicebus/azure-servicebus/tests/test_queues.py b/sdk/servicebus/azure-servicebus/tests/test_queues.py index bbe3805cfd2f6..345c677a44cd3 100644 --- a/sdk/servicebus/azure-servicebus/tests/test_queues.py +++ b/sdk/servicebus/azure-servicebus/tests/test_queues.py @@ -7,6 +7,7 @@ import logging import sys import os +import json from concurrent.futures import ThreadPoolExecutor import types import pytest @@ -15,10 +16,17 @@ from datetime import datetime, timedelta import calendar import unittest +import pickle -import uamqp -import uamqp.errors -from uamqp import compat +try: + import uamqp + from azure.servicebus._transport._uamqp_transport import UamqpTransport +except ImportError: + uamqp = None + +from azure.servicebus._transport._pyamqp_transport import PyamqpTransport +from azure.servicebus._pyamqp.message import Message +from azure.servicebus._pyamqp import error, client, management_operation from azure.servicebus import ( ServiceBusClient, AutoLockRenewer, @@ -55,7 +63,7 @@ OperationTimeoutError ) -from devtools_testutils import AzureMgmtTestCase +from devtools_testutils import AzureMgmtRecordedTestCase from servicebus_preparer import ( CachedServiceBusNamespacePreparer, ServiceBusQueuePreparer, @@ -64,6 +72,9 @@ ) from utilities import get_logger, print_message, sleep_until_expired from mocks import MockReceivedMessage, MockReceiver +from utilities import uamqp_transport as get_uamqp_transport, ArgPasser + +uamqp_transport_params, uamqp_transport_ids = get_uamqp_transport() _logger = get_logger(logging.DEBUG) @@ -72,17 +83,19 @@ # Old servicebus tests were not written to work on both stubs and live entities. # This disables those tests for non-live scenarios, and should be removed as tests # are ported to offline-compatible code. 
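The hunk above makes `uamqp` an optional import and pulls `uamqp_transport_params` / `uamqp_transport_ids` from a `utilities` helper so every test can be parametrized over both transports. As a minimal sketch (the helper's internals here are an assumption, not the repo's actual code), such a helper could look like:

```python
# Hypothetical sketch of the utilities.uamqp_transport helper assumed above:
# offer both transports when uamqp is importable, otherwise pyamqp only.
def uamqp_transport():
    try:
        import uamqp  # noqa: F401  # optional legacy AMQP transport
        params = [True, False]
        ids = ["uamqp", "pyamqp"]
    except ImportError:
        params = [False]
        ids = ["pyamqp"]
    return params, ids

uamqp_transport_params, uamqp_transport_ids = uamqp_transport()
# These feed @pytest.mark.parametrize("uamqp_transport", params, ids=ids),
# so each test body runs once per installed transport and the boolean is
# forwarded to ServiceBusClient.from_connection_string(uamqp_transport=...).
```

Keeping the `ids` list parallel to `params` gives readable test names like `test_foo[pyamqp]` instead of `test_foo[False]`.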
-class ServiceBusQueueTests(AzureMgmtTestCase): +class TestServiceBusQueue(AzureMgmtRecordedTestCase): @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) - def test_receive_and_delete_reconnect_interaction(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_receive_and_delete_reconnect_interaction(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): # Note: This test was to guard against github issue 7079 sb_client = ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) with sb_client.get_queue_sender(servicebus_queue.name) as sender: for i in range(5): @@ -99,14 +112,17 @@ def test_receive_and_delete_reconnect_interaction(self, servicebus_namespace_con count += 1 assert count == 5 + @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer() @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @CachedServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True, lock_duration='PT5S') - def test_github_issue_6178(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_github_issue_6178(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, 
logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: with sb_client.get_queue_sender(servicebus_queue.name) as sender: for i in range(3): @@ -121,14 +137,17 @@ def test_github_issue_6178(self, servicebus_namespace_connection_string, service receiver.complete_message(message) time.sleep(10) + @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True, lock_duration='PT10S') - def test_queue_by_queue_client_conn_str_receive_handler_peeklock(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_queue_by_queue_client_conn_str_receive_handler_peeklock(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: sender = sb_client.get_queue_sender(servicebus_queue.name) for i in range(10): @@ -170,9 +189,6 @@ def test_queue_by_queue_client_conn_str_receive_handler_peeklock(self, servicebu with pytest.raises(ValueError): receiver.receive_messages(max_wait_time=0) - with pytest.raises(ValueError): - receiver._get_streaming_message_iter(max_wait_time=0) - count = 0 for message in receiver: print_message(_logger, message) @@ -215,9 +231,11 @@ def test_queue_by_queue_client_conn_str_receive_handler_peeklock(self, servicebu @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') 
@CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True, lock_duration='PT10S') - def test_queue_by_queue_client_conn_str_receive_handler_release_messages(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_queue_by_queue_client_conn_str_receive_handler_release_messages(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: def sub_test_releasing_messages(): # test releasing messages when prefetch is 1 and link credits are issue dynamically @@ -317,21 +335,32 @@ def sub_test_releasing_messages_iterator(): receiver.complete_message(msg) assert outter_recv_cnt == 4 + def sub_test_non_releasing_messages(): # test not releasing messages when prefetch is not 1 receiver = sb_client.get_queue_receiver(servicebus_queue.name) sender = sb_client.get_queue_sender(servicebus_queue.name) - def _hack_disable_receive_context_message_received(self, message): - # pylint: disable=protected-access - self._handler._was_message_received = True - self._handler._received_messages.put(message) + if uamqp_transport: + def _hack_disable_receive_context_message_received(self, message): + # pylint: disable=protected-access + self._handler._was_message_received = True + self._handler._received_messages.put(message) + else: + def _hack_disable_receive_context_message_received(self, frame, message): + # pylint: disable=protected-access + self._handler._last_activity_timestamp = time.time() + self._handler._received_messages.put((frame, message)) with sender, receiver: # 
send 5 msgs to queue first sender.send_messages([ServiceBusMessage('test') for _ in range(5)]) - receiver._handler.message_handler.on_message_received = types.MethodType( - _hack_disable_receive_context_message_received, receiver) + if uamqp_transport: + receiver._handler.message_handler.on_message_received = types.MethodType( + _hack_disable_receive_context_message_received, receiver) + else: + receiver._handler._link._on_transfer = types.MethodType( + _hack_disable_receive_context_message_received, receiver) received_msgs = [] while len(received_msgs) < 5: # issue 10 link credits, client should consume 5 msgs from the service @@ -352,6 +381,7 @@ def _hack_disable_receive_context_message_received(self, message): received_msgs.extend(receiver.receive_messages(max_message_count=5, max_wait_time=5)) assert len(received_msgs) == 5 for msg in received_msgs: + # first delivery attempt, so delivery_count should still be 0 assert msg.delivery_count == 0 with pytest.raises(ServiceBusError): receiver.complete_message(msg) @@ -370,14 +400,17 @@ def _hack_disable_receive_context_message_received(self, message): sub_test_releasing_messages_iterator() sub_test_non_releasing_messages() + @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) - def test_queue_by_queue_client_send_multiple_messages(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_queue_by_queue_client_send_multiple_messages(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + 
servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: sender = sb_client.get_queue_sender(servicebus_queue.name) with sender: @@ -434,15 +467,72 @@ def test_queue_by_queue_client_send_multiple_messages(self, servicebus_namespace with pytest.raises(ValueError): receiver.peek_messages() + sender = sb_client.get_queue_sender(servicebus_queue.name) + receiver = sb_client.get_queue_receiver(servicebus_queue.name, max_wait_time=5) + with sender, receiver: + # send previously unpicklable message + msg = { + "body":"W1tdLCB7ImlucHV0X2lkIjogNH0sIHsiY2FsbGJhY2tzIjogbnVsbCwgImVycmJhY2tzIjogbnVsbCwgImNoYWluIjogbnVsbCwgImNob3JkIjogbnVsbH1d", + "content-encoding":"utf-8", + "content-type":"application/json", + "headers":{ + "lang":"py", + "task":"tasks.example_task", + "id":"7c66557d-e4bc-437f-b021-b66dcc39dfdf", + "shadow":None, + "eta":"2021-10-07T02:30:23.764066+00:00", + "expires":None, + "group":None, + "group_index":None, + "retries":1, + "timelimit":[ + None, + None + ], + "root_id":"7c66557d-e4bc-437f-b021-b66dcc39dfdf", + "parent_id":"7c66557d-e4bc-437f-b021-b66dcc39dfdf", + "argsrepr":"()", + "kwargsrepr":"{'input_id': 4}", + "origin":"gen36@94713e01a9c0", + "ignore_result":1, + "x_correlator":"44a1978d-c869-4173-afe4-da741f0edfb9" + }, + "properties":{ + "correlation_id":"7c66557d-e4bc-437f-b021-b66dcc39dfdf", + "reply_to":"7b9a3672-2fed-3e9b-8bfd-23ae2397d9ad", + "origin":"gen68@c33d4eef123a", + "delivery_mode":2, + "delivery_info":{ + "exchange":"", + "routing_key":"celery_task_queue" + }, + "priority":0, + "body_encoding":"base64", + "delivery_tag":"dc83ddb6-8cdc-4413-b88a-06c56cbde90d" + } + } + sender.send_messages(ServiceBusMessage(json.dumps(msg))) + messages = receiver.receive_messages(max_wait_time=10, max_message_count=1) + # complete first then pickle + receiver.complete_message(messages[0]) + if not uamqp_transport: + pickled = pickle.loads(pickle.dumps(messages[0])) + assert json.loads(str(pickled)) 
== json.loads(str(messages[0])) + else: + with pytest.raises(TypeError): + pickled = pickle.loads(pickle.dumps(messages[0])) + @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True, lock_duration='PT10S') - def test_queue_by_queue_client_conn_str_receive_handler_receiveanddelete(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_queue_by_queue_client_conn_str_receive_handler_receiveanddelete(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: with sb_client.get_queue_sender(servicebus_queue.name) as sender: for i in range(10): @@ -484,17 +574,19 @@ def test_queue_by_queue_client_conn_str_receive_handler_receiveanddelete(self, s messages.append(message) assert len(messages) == 0 - + @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) - def test_queue_by_queue_client_conn_str_receive_handler_with_stop(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_queue_by_queue_client_conn_str_receive_handler_with_stop(self, uamqp_transport, *, 
servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: with sb_client.get_queue_sender(servicebus_queue.name) as sender: for i in range(10): @@ -521,15 +613,18 @@ def test_queue_by_queue_client_conn_str_receive_handler_with_stop(self, serviceb assert not receiver._running assert len(messages) == 6 + @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) - def test_queue_by_servicebus_client_iter_messages_simple(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_queue_by_servicebus_client_iter_messages_simple(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: with sb_client.get_queue_receiver(servicebus_queue.name, max_wait_time=10, @@ -554,16 +649,18 @@ def test_queue_by_servicebus_client_iter_messages_simple(self, servicebus_namesp next(receiver) assert count == 10 - + @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) - def 
test_queue_by_servicebus_conn_str_client_iter_messages_with_abandon(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_queue_by_servicebus_conn_str_client_iter_messages_with_abandon(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: with sb_client.get_queue_receiver(servicebus_queue.name, max_wait_time=5, receive_mode=ServiceBusReceiveMode.PEEK_LOCK) as receiver: @@ -598,10 +695,12 @@ def test_queue_by_servicebus_conn_str_client_iter_messages_with_abandon(self, se @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) - def test_queue_by_servicebus_client_iter_messages_with_defer(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_queue_by_servicebus_client_iter_messages_with_defer(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: deferred_messages = [] with sb_client.get_queue_receiver( @@ -630,16 +729,18 @@ def test_queue_by_servicebus_client_iter_messages_with_defer(self, servicebus_na count += 1 assert count == 0 - + 
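The defer tests above all follow the same shape: receive a message, defer it instead of settling, and later fetch it back explicitly by sequence number with `receive_deferred_messages`. A minimal in-memory sketch of that flow (a hypothetical stand-in using plain Python containers, not the SDK types):

```python
# Conceptual sketch of the defer/retrieve-deferred pattern, assuming an
# in-memory "queue" of (body, sequence_number) pairs stands in for the service.
deferred = {}  # sequence_number -> body, like the service's deferred store
queue = [("msg-%d" % i, i) for i in range(3)]

# First pass: instead of completing, defer each message and remember its
# sequence number (the SDK equivalent is receiver.defer_message(msg)).
sequence_numbers = []
for body, seq in queue:
    deferred[seq] = body
    sequence_numbers.append(seq)

# Later: deferred messages are no longer delivered by normal receives; they
# must be requested by sequence number (receiver.receive_deferred_messages).
retrieved = [deferred.pop(seq) for seq in sequence_numbers]
print(retrieved)
```

Requesting a sequence number that was never deferred (or was already settled) is what makes the `receive_deferred_messages([5, 6, 7])` call in the not-found test raise `ServiceBusError`.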
@pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) - def test_queue_by_servicebus_client_iter_messages_with_retrieve_deferred_client(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_queue_by_servicebus_client_iter_messages_with_retrieve_deferred_client(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: deferred_messages = [] with sb_client.get_queue_receiver(servicebus_queue.name, @@ -668,15 +769,18 @@ def test_queue_by_servicebus_client_iter_messages_with_retrieve_deferred_client( with pytest.raises(ServiceBusError): receiver.receive_deferred_messages(deferred_messages) + @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) - def test_queue_by_servicebus_client_iter_messages_with_retrieve_deferred_receiver_complete(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_queue_by_servicebus_client_iter_messages_with_retrieve_deferred_receiver_complete(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, 
**kwargs): with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: with sb_client.get_queue_sender(servicebus_queue.name) as sender: deferred_messages = [] @@ -709,14 +813,17 @@ def test_queue_by_servicebus_client_iter_messages_with_retrieve_deferred_receive receiver.renew_message_lock(message) receiver.complete_message(message) + @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) - def test_queue_by_servicebus_client_iter_messages_with_retrieve_deferred_receiver_deadletter(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_queue_by_servicebus_client_iter_messages_with_retrieve_deferred_receiver_deadletter(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: with sb_client.get_queue_sender(servicebus_queue.name) as sender: deferred_messages = [] for i in range(10): @@ -757,14 +864,17 @@ def test_queue_by_servicebus_client_iter_messages_with_retrieve_deferred_receive receiver.complete_message(message) assert count == 10 + @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') 
@ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) - def test_queue_by_servicebus_client_iter_messages_with_retrieve_deferred_receiver_deletemode(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_queue_by_servicebus_client_iter_messages_with_retrieve_deferred_receiver_deletemode(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: with sb_client.get_queue_sender(servicebus_queue.name) as sender: for i in range(10): @@ -792,14 +902,17 @@ def test_queue_by_servicebus_client_iter_messages_with_retrieve_deferred_receive with pytest.raises(ServiceBusError): deferred = receiver.receive_deferred_messages(deferred_messages) + @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) - def test_queue_by_servicebus_client_iter_messages_with_retrieve_deferred_not_found(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_queue_by_servicebus_client_iter_messages_with_retrieve_deferred_not_found(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + 
servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: deferred_messages = [] with sb_client.get_queue_receiver(servicebus_queue.name, @@ -826,15 +939,18 @@ def test_queue_by_servicebus_client_iter_messages_with_retrieve_deferred_not_fou with pytest.raises(ServiceBusError): deferred = receiver.receive_deferred_messages([5, 6, 7]) + @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) - def test_queue_by_servicebus_client_receive_batch_with_deadletter(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_queue_by_servicebus_client_receive_batch_with_deadletter(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: with sb_client.get_queue_receiver(servicebus_queue.name, max_wait_time=5, @@ -883,15 +999,18 @@ def test_queue_by_servicebus_client_receive_batch_with_deadletter(self, serviceb assert message.application_properties[b'DeadLetterErrorDescription'] == b'Testing description' assert count == 10 + @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) - def test_queue_by_servicebus_client_receive_batch_with_retrieve_deadletter(self, 
servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_queue_by_servicebus_client_receive_batch_with_retrieve_deadletter(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: with sb_client.get_queue_receiver(servicebus_queue.name, max_wait_time=5, @@ -939,10 +1058,12 @@ def test_queue_by_servicebus_client_receive_batch_with_retrieve_deadletter(self, @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @CachedServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) - def test_queue_by_servicebus_client_session_fail(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_queue_by_servicebus_client_session_fail(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: with pytest.raises(ServiceBusError): sb_client.get_queue_receiver(servicebus_queue.name, session_id="test")._open_with_retry() @@ -956,10 +1077,12 @@ def test_queue_by_servicebus_client_session_fail(self, servicebus_namespace_conn @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') 
@CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) - def test_queue_by_servicebus_client_browse_messages_client(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_queue_by_servicebus_client_browse_messages_client(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: with sb_client.get_queue_sender(servicebus_queue.name) as sender: for i in range(5): @@ -980,10 +1103,12 @@ def test_queue_by_servicebus_client_browse_messages_client(self, servicebus_name @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) - def test_queue_by_servicebus_client_browse_messages_with_receiver(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_queue_by_servicebus_client_browse_messages_with_receiver(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: receiver = sb_client.get_queue_receiver(servicebus_queue.name, max_wait_time=5, @@ -1043,10 +1168,12 @@ def 
test_queue_by_servicebus_client_browse_messages_with_receiver(self, serviceb @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) - def test_queue_by_servicebus_client_browse_empty_messages(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_queue_by_servicebus_client_browse_empty_messages(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: with sb_client.get_queue_receiver(servicebus_queue.name, max_wait_time=5, @@ -1061,10 +1188,12 @@ def test_queue_by_servicebus_client_browse_empty_messages(self, servicebus_names @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @CachedServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) - def test_queue_by_servicebus_client_fail_send_messages(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_queue_by_servicebus_client_fail_send_messages(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as 
sb_client: too_large = "A" * 256 * 1024 @@ -1081,10 +1210,12 @@ def test_queue_by_servicebus_client_fail_send_messages(self, servicebus_namespac @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) - def test_queue_by_servicebus_client_renew_message_locks(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_queue_by_servicebus_client_renew_message_locks(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: messages = [] locks = 3 @@ -1117,16 +1248,19 @@ def test_queue_by_servicebus_client_renew_message_locks(self, servicebus_namespa sleep_until_expired(messages[2]) with pytest.raises(ServiceBusError): receiver.complete_message(messages[2]) + @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True, lock_duration='PT10S') - def test_queue_by_queue_client_conn_str_receive_handler_with_autolockrenew(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_queue_by_queue_client_conn_str_receive_handler_with_autolockrenew(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, 
**kwargs): with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: with sb_client.get_queue_sender(servicebus_queue.name) as sender: for i in range(10): @@ -1235,15 +1369,18 @@ def test_queue_by_queue_client_conn_str_receive_handler_with_autolockrenew(self, assert renewer._is_max_workers_greater_than_one renewer.close() + @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True, lock_duration='PT10S') - def test_queue_by_queue_client_conn_str_receive_handler_with_auto_autolockrenew(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_queue_by_queue_client_conn_str_receive_handler_with_auto_autolockrenew(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: with sb_client.get_queue_sender(servicebus_queue.name) as sender: # The 10 iterations are "important" because they give time for the timed out message to be received again. 
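The autolockrenew tests exercise `AutoLockRenewer`, which keeps a PEEK_LOCK message's lock alive on a background worker until the message is settled or the renewer is closed. A toy sketch of that mechanism (a hypothetical stand-in built on `threading`, not the SDK class, with a dict standing in for the message's lock state):

```python
import threading
import time

class ToyLockRenewer:
    """Conceptual stand-in for AutoLockRenewer: a background thread that
    repeatedly pushes a lock's expiry into the future until closed."""

    def __init__(self, interval=0.05):
        self.interval = interval          # how often to renew
        self._stop = threading.Event()
        self._threads = []

    def register(self, lock):
        # One renewal loop per registered "message"
        t = threading.Thread(target=self._renew_loop, args=(lock,), daemon=True)
        t.start()
        self._threads.append(t)

    def _renew_loop(self, lock):
        while not self._stop.is_set():
            # Pretend the service granted a renewal: extend expiry by 1 second
            lock["locked_until"] = time.monotonic() + 1.0
            if self._stop.wait(self.interval):
                break

    def close(self):
        self._stop.set()
        for t in self._threads:
            t.join()

# A "message" whose lock would expire after only 0.1 s without renewal
lock = {"locked_until": time.monotonic() + 0.1}
renewer = ToyLockRenewer()
renewer.register(lock)
time.sleep(0.3)  # longer than the original lock duration
still_locked = lock["locked_until"] > time.monotonic()
renewer.close()
print(still_locked)
```

Without the renewer the lock would have lapsed at 0.1 s; with it, the expiry keeps moving forward, which is the behavior the live tests assert via `renew_message_lock` errors after expiry versus successful `complete_message` while renewed.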
@@ -1289,15 +1426,18 @@ def test_queue_by_queue_client_conn_str_receive_handler_with_auto_autolockrenew( renewer.close() assert len(messages) == 11 + @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) - def test_queue_message_time_to_live(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_queue_message_time_to_live(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: with sb_client.get_queue_sender(servicebus_queue.name) as sender: content = str(uuid.uuid4()) @@ -1323,15 +1463,18 @@ def test_queue_message_time_to_live(self, servicebus_namespace_connection_string count += 1 assert count == 1 + @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', requires_duplicate_detection=True, dead_lettering_on_message_expiration=True) - def test_queue_message_duplicate_detection(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_queue_message_duplicate_detection(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): with 
ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: message_id = uuid.uuid4() @@ -1357,10 +1500,12 @@ def test_queue_message_duplicate_detection(self, servicebus_namespace_connection @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) - def test_queue_message_connection_closed(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_queue_message_connection_closed(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: with sb_client.get_queue_sender(servicebus_queue.name) as sender: content = str(uuid.uuid4()) @@ -1380,10 +1525,12 @@ def test_queue_message_connection_closed(self, servicebus_namespace_connection_s @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) - def test_queue_message_expiry(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_queue_message_expiry(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): with 
ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: with sb_client.get_queue_sender(servicebus_queue.name) as sender: content = str(uuid.uuid4()) @@ -1413,10 +1560,12 @@ def test_queue_message_expiry(self, servicebus_namespace_connection_string, serv @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) - def test_queue_message_lock_renew(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_queue_message_lock_renew(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: with sb_client.get_queue_sender(servicebus_queue.name) as sender: content = str(uuid.uuid4()) @@ -1444,10 +1593,19 @@ def test_queue_message_lock_renew(self, servicebus_namespace_connection_string, @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True, lock_duration='PT10S') - def test_queue_message_receive_and_delete(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): - + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_queue_message_receive_and_delete(self, 
uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): + if uamqp: + transport_type = uamqp.constants.TransportType.Amqp + else: + transport_type = TransportType.Amqp with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, + logging_enable=False, + transport_type=transport_type, + uamqp_transport=uamqp_transport + ) as sb_client: with sb_client.get_queue_sender(servicebus_queue.name) as sender: message = ServiceBusMessage("Receive and delete test") @@ -1484,10 +1642,12 @@ def test_queue_message_receive_and_delete(self, servicebus_namespace_connection_ @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) - def test_queue_message_batch(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_queue_message_batch(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: def message_content(): for i in range(5): @@ -1505,6 +1665,7 @@ def message_content(): yield message with sb_client.get_queue_sender(servicebus_queue.name) as sender: + # sending manually created message batch (with default pyamqp) should work for both uamqp/pyamqp message = ServiceBusMessageBatch() for each in message_content(): message.add_message(each) @@ -1542,10 +1703,12 @@ def message_content(): 
@CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) - def test_queue_schedule_message(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_queue_schedule_message(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: scheduled_enqueue_time = (utc_now() + timedelta(minutes=2)).replace(microsecond=0) with sb_client.get_queue_receiver(servicebus_queue.name) as receiver: @@ -1572,20 +1735,21 @@ def test_queue_schedule_message(self, servicebus_namespace_connection_string, se else: raise Exception("Failed to receive scheduled message.") - @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) - def test_queue_schedule_multiple_messages(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_queue_schedule_multiple_messages(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False,
uamqp_transport=uamqp_transport) as sb_client: scheduled_enqueue_time = (utc_now() + timedelta(minutes=2)).replace(microsecond=0) sender = sb_client.get_queue_sender(servicebus_queue.name) - receiver = sb_client.get_queue_receiver(servicebus_queue.name, prefetch_count=20) + receiver = sb_client.get_queue_receiver(servicebus_queue.name, prefetch_count=20, max_wait_time=5) with sender, receiver: content = str(uuid.uuid4()) @@ -1607,7 +1771,7 @@ def test_queue_schedule_multiple_messages(self, servicebus_namespace_connection_ sender.send_messages(message_arry) received_messages = [] - for message in receiver._get_streaming_message_iter(max_wait_time=5): + for message in receiver: received_messages.append(message) receiver.complete_message(message) @@ -1635,6 +1799,22 @@ def test_queue_schedule_multiple_messages(self, servicebus_namespace_connection_ assert messages[0].enqueued_time_utc assert messages[0].message.delivery_tag is not None assert len(messages) == 2 + + if not uamqp_transport: + pickled = pickle.loads(pickle.dumps(messages[0])) + assert pickled.message_id == messages[0].message_id + assert pickled.scheduled_enqueue_time_utc == messages[0].scheduled_enqueue_time_utc + assert pickled.scheduled_enqueue_time_utc <= pickled.enqueued_time_utc.replace(microsecond=0) + assert pickled.delivery_count == messages[0].delivery_count + assert pickled.application_properties == messages[0].application_properties + assert pickled.application_properties[b'key'] == messages[0].application_properties[b'key'] + assert pickled.subject == messages[0].subject + assert pickled.content_type == messages[0].content_type + assert pickled.correlation_id == messages[0].correlation_id + assert pickled.to == messages[0].to + assert pickled.reply_to == messages[0].reply_to + assert pickled.sequence_number + assert pickled.enqueued_time_utc finally: for message in messages: receiver.complete_message(message) @@ -1646,10 +1826,12 @@ def test_queue_schedule_multiple_messages(self, 
servicebus_namespace_connection_ @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) - def test_queue_cancel_scheduled_messages(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_queue_cancel_scheduled_messages(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: enqueue_time = (utc_now() + timedelta(minutes=2)).replace(microsecond=0) with sb_client.get_queue_receiver(servicebus_queue.name) as receiver: @@ -1676,11 +1858,13 @@ def test_queue_cancel_scheduled_messages(self, servicebus_namespace_connection_s @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) - def test_queue_message_amqp_over_websocket(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_queue_message_amqp_over_websocket(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): with ServiceBusClient.from_connection_string( servicebus_namespace_connection_string, transport_type=TransportType.AmqpOverWebsocket, - logging_enable=False) as sb_client: + logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: with 
sb_client.get_queue_sender(servicebus_queue.name) as sender: assert sender._config.transport_type == TransportType.AmqpOverWebsocket @@ -1692,7 +1876,8 @@ def test_queue_message_amqp_over_websocket(self, servicebus_namespace_connection messages = receiver.receive_messages(max_wait_time=5) assert len(messages) == 1 - def test_queue_message_http_proxy_setting(self): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + def test_queue_message_http_proxy_setting(self, uamqp_transport): mock_conn_str = "Endpoint=sb://mock.servicebus.windows.net/;SharedAccessKeyName=mock;SharedAccessKey=mock" http_proxy = { 'proxy_hostname': '127.0.0.1', @@ -1701,7 +1886,7 @@ def test_queue_message_http_proxy_setting(self): 'password': '123456' } - sb_client = ServiceBusClient.from_connection_string(mock_conn_str, http_proxy=http_proxy) + sb_client = ServiceBusClient.from_connection_string(mock_conn_str, http_proxy=http_proxy, uamqp_transport=uamqp_transport) assert sb_client._config.http_proxy == http_proxy assert sb_client._config.transport_type == TransportType.AmqpOverWebsocket @@ -1718,10 +1903,12 @@ def test_queue_message_http_proxy_setting(self): @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) - def test_queue_message_settle_through_mgmt_link_due_to_broken_receiver_link(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_queue_message_settle_through_mgmt_link_due_to_broken_receiver_link(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): with ServiceBusClient.from_connection_string( servicebus_namespace_connection_string, - logging_enable=False) as sb_client: + 
logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: with sb_client.get_queue_sender(servicebus_queue.name) as sender: message = ServiceBusMessage("Test") @@ -1729,7 +1916,11 @@ def test_queue_message_settle_through_mgmt_link_due_to_broken_receiver_link(self with sb_client.get_queue_receiver(servicebus_queue.name) as receiver: messages = receiver.receive_messages(max_wait_time=5) - receiver._handler.message_handler.destroy() # destroy the underlying receiver link + # destroy the underlying receiver link + if uamqp_transport: + receiver._handler.message_handler.destroy() + else: + receiver._handler._link.detach() assert len(messages) == 1 receiver.complete_message(messages[0]) @@ -1848,7 +2039,9 @@ def test_queue_mock_no_reusing_auto_lock_renew(self): with pytest.raises(ServiceBusError): auto_lock_renew.register(receiver, renewable=MockReceivedMessage()) - def test_queue_message_properties(self): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + def test_queue_message_properties(self, uamqp_transport): + scheduled_enqueue_time = (utc_now() + timedelta(seconds=20)).replace(microsecond=0) message = ServiceBusMessage( body='data', @@ -1891,16 +2084,30 @@ def test_queue_message_properties(self): except AttributeError: timestamp = calendar.timegm(new_scheduled_time.timetuple()) * 1000 - uamqp_received_message = uamqp.message.Message( - body=b'data', - annotations={ - _X_OPT_PARTITION_KEY: b'r_key', - _X_OPT_VIA_PARTITION_KEY: b'r_via_key', - _X_OPT_SCHEDULED_ENQUEUE_TIME: timestamp, - }, - properties=uamqp.message.MessageProperties() - ) - received_message = ServiceBusReceivedMessage(uamqp_received_message, receiver=None) + my_frame = [0,0,0] + if uamqp_transport: + amqp_transport = UamqpTransport + amqp_received_message = uamqp.message.Message( + body=[b'data'], + annotations={ + _X_OPT_PARTITION_KEY: b'r_key', + _X_OPT_VIA_PARTITION_KEY: b'r_via_key', + _X_OPT_SCHEDULED_ENQUEUE_TIME: timestamp, + }, + 
properties={} + ) + else: + amqp_transport = PyamqpTransport + amqp_received_message = Message( + data=[b'data'], + message_annotations={ + _X_OPT_PARTITION_KEY: b'r_key', + _X_OPT_VIA_PARTITION_KEY: b'r_via_key', + _X_OPT_SCHEDULED_ENQUEUE_TIME: timestamp, + }, + properties={} + ) + received_message = ServiceBusReceivedMessage(amqp_received_message, receiver=None, frame=my_frame, amqp_transport=amqp_transport) assert received_message.scheduled_enqueue_time_utc == new_scheduled_time new_scheduled_time = utc_now() + timedelta(hours=1, minutes=49, seconds=32) @@ -1918,26 +2125,28 @@ def test_queue_message_properties(self): @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) - def test_queue_receive_batch_without_setting_prefetch(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_queue_receive_batch_without_setting_prefetch(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: def message_content(): for i in range(20): yield ServiceBusMessage( body="Test message", - application_properties={'key': 'value'}, + # application_properties={'key': 'value'}, subject='1st', - content_type='application/text', - correlation_id='cid', - message_id='mid', - to='to', - reply_to='reply_to', - time_to_live=timedelta(seconds=60) + # content_type='application/text', + # correlation_id='cid', + # message_id='mid', + # to='to', + # reply_to='reply_to', + # 
time_to_live=timedelta(seconds=60) ) sender = sb_client.get_queue_sender(servicebus_queue.name) - receiver = sb_client.get_queue_receiver(servicebus_queue.name) + receiver = sb_client.get_queue_receiver(servicebus_queue.name, max_wait_time=10) with sender, receiver: message = ServiceBusMessageBatch() @@ -1950,21 +2159,21 @@ def message_content(): message_2nd_received_cnt = 0 while message_1st_received_cnt < 20 or message_2nd_received_cnt < 20: messages = [] - for message in receiver._get_streaming_message_iter(max_wait_time=10): + for message in receiver: messages.append(message) if not messages: break receive_counter += 1 for message in messages: print_message(_logger, message) - assert b''.join(message.body) == b'Test message' - assert message.application_properties[b'key'] == b'value' - assert message.content_type == 'application/text' - assert message.correlation_id == 'cid' - assert message.message_id == 'mid' - assert message.to == 'to' - assert message.reply_to == 'reply_to' - assert message.time_to_live == timedelta(seconds=60) + # assert b''.join(message.body) == b'Test message' + # assert message.application_properties[b'key'] == b'value' + # assert message.content_type == 'application/text' + # assert message.correlation_id == 'cid' + # assert message.message_id == 'mid' + # assert message.to == 'to' + # assert message.reply_to == 'reply_to' + # assert message.time_to_live == timedelta(seconds=60) if message.subject == '1st': message_1st_received_cnt += 1 @@ -1979,15 +2188,18 @@ def message_content(): # Network/server might be unstable making flow control ineffective in the leading rounds of connection iteration assert receive_counter < 10 # Dynamic link credit issuing comes into effect + @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest') - def
test_queue_receiver_alive_after_timeout(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_queue_receiver_alive_after_timeout(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): with ServiceBusClient.from_connection_string( servicebus_namespace_connection_string, - logging_enable=False) as sb_client: + logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: with sb_client.get_queue_sender(servicebus_queue.name) as sender: message = ServiceBusMessage("0") @@ -1997,11 +2209,11 @@ def test_queue_receiver_alive_after_timeout(self, servicebus_namespace_connectio messages = [] with sb_client.get_queue_receiver(servicebus_queue.name, max_wait_time=5) as receiver: - for message in receiver._get_streaming_message_iter(): + for message in receiver: messages.append(message) break - for message in receiver._get_streaming_message_iter(): + for message in receiver: messages.append(message) for message in messages: @@ -2015,9 +2227,9 @@ def test_queue_receiver_alive_after_timeout(self, servicebus_namespace_connectio message_3 = ServiceBusMessage("3") sender.send_messages([message_2, message_3]) - for message in receiver._get_streaming_message_iter(): + for message in receiver: messages.append(message) - for message in receiver._get_streaming_message_iter(): + for message in receiver: messages.append(message) assert len(messages) == 4 @@ -2030,14 +2242,17 @@ def test_queue_receiver_alive_after_timeout(self, servicebus_namespace_connectio messages = receiver.receive_messages() assert not messages + @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True, 
lock_duration='PT5M') - def test_queue_receive_keep_conn_alive(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_queue_receive_keep_conn_alive(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: sender = sb_client.get_queue_sender(servicebus_queue.name) receiver = sb_client.get_queue_receiver(servicebus_queue.name, max_wait_time=10) @@ -2065,15 +2280,18 @@ def test_queue_receive_keep_conn_alive(self, servicebus_namespace_connection_str assert len(messages) == 0 # make sure messages are removed from the queue assert receiver_handler == receiver._handler # make sure no reconnection happened + @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest') - def test_queue_receiver_sender_resume_after_link_timeout(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_queue_receiver_sender_resume_after_link_timeout(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): with ServiceBusClient.from_connection_string( servicebus_namespace_connection_string, - logging_enable=False) as sb_client: + logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: with sb_client.get_queue_sender(servicebus_queue.name) as sender: message = ServiceBusMessage("0") @@ -2087,67 +2305,26 @@ 
def test_queue_receiver_sender_resume_after_link_timeout(self, servicebus_namesp messages = [] with sb_client.get_queue_receiver(servicebus_queue.name, max_wait_time=5) as receiver: - for message in receiver._get_streaming_message_iter(): + for message in receiver: messages.append(message) assert len(messages) == 2 - - + @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest') - def test_queue_receiver_respects_max_wait_time_overrides(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): - with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, - logging_enable=False) as sb_client: - - with sb_client.get_queue_sender(servicebus_queue.name) as sender: - message = ServiceBusMessage("0") - sender.send_messages(message) + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_queue_send_twice(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): + if uamqp: + transport_type = uamqp.constants.TransportType.AmqpOverWebsocket + else: + transport_type = TransportType.AmqpOverWebsocket - messages = [] - with sb_client.get_queue_receiver(servicebus_queue.name, max_wait_time=5) as receiver: - - time_1 = receiver._handler._counter.get_current_ms() - time_3 = time_1 # In case inner loop isn't hit, fail sanely. 
- for message in receiver._get_streaming_message_iter(max_wait_time=10): - messages.append(message) - receiver.complete_message(message) - - time_2 = receiver._handler._counter.get_current_ms() - for message in receiver._get_streaming_message_iter(max_wait_time=1): - messages.append(message) - time_3 = receiver._handler._counter.get_current_ms() - assert timedelta(seconds=.5) < timedelta(milliseconds=(time_3 - time_2)) <= timedelta(seconds=2) - time_4 = receiver._handler._counter.get_current_ms() - assert timedelta(seconds=8) < timedelta(milliseconds=(time_4 - time_3)) <= timedelta(seconds=11) - - for message in receiver._get_streaming_message_iter(max_wait_time=3): - messages.append(message) - time_5 = receiver._handler._counter.get_current_ms() - assert timedelta(seconds=1) < timedelta(milliseconds=(time_5 - time_4)) <= timedelta(seconds=4) - - for message in receiver: - messages.append(message) - time_6 = receiver._handler._counter.get_current_ms() - assert timedelta(seconds=3) < timedelta(milliseconds=(time_6 - time_5)) <= timedelta(seconds=6) - - for message in receiver._get_streaming_message_iter(): - messages.append(message) - time_7 = receiver._handler._counter.get_current_ms() - assert timedelta(seconds=3) < timedelta(milliseconds=(time_7 - time_6)) <= timedelta(seconds=6) - assert len(messages) == 1 - - - @pytest.mark.liveTest - @pytest.mark.live_test_only - @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') - @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') - @ServiceBusQueuePreparer(name_prefix='servicebustest') - def test_queue_send_twice(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, + transport_type=transport_type, uamqp_transport=uamqp_transport) as sb_client: with 
sb_client.get_queue_sender(servicebus_queue.name) as sender: message = ServiceBusMessage("ServiceBusMessage") @@ -2166,12 +2343,17 @@ def test_queue_send_twice(self, servicebus_namespace_connection_string, serviceb # then normal message resending sender.send_messages(message) sender.send_messages(message) + expected_count = 2 + if not uamqp_transport: + pickled_recvd = pickle.loads(pickle.dumps(messages[0])) + sender.send_messages(pickled_recvd) + expected_count = 3 messages = [] with sb_client.get_queue_receiver(servicebus_queue.name, max_wait_time=20) as receiver: for message in receiver: messages.append(message) receiver.complete_message(message) - assert len(messages) == 2 + assert len(messages) == expected_count @pytest.mark.liveTest @@ -2179,10 +2361,12 @@ def test_queue_send_twice(self, servicebus_namespace_connection_string, serviceb @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @CachedServiceBusQueuePreparer(name_prefix='servicebustest') - def test_queue_receiver_invalid_mode(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_queue_receiver_invalid_mode(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: with pytest.raises(ValueError): with sb_client.get_queue_receiver(servicebus_queue.name, receive_mode=2) as receiver: @@ -2194,10 +2378,12 @@ def test_queue_receiver_invalid_mode(self, servicebus_namespace_connection_strin @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') 
@CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @CachedServiceBusQueuePreparer(name_prefix='servicebustest') - def test_queue_receiver_invalid_autolockrenew_mode(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_queue_receiver_invalid_autolockrenew_mode(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: with pytest.raises(ValueError): with sb_client.get_queue_receiver(servicebus_queue.name, receive_mode=ServiceBusReceiveMode.RECEIVE_AND_DELETE, @@ -2210,7 +2396,9 @@ def test_queue_receiver_invalid_autolockrenew_mode(self, servicebus_namespace_co @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest') - def test_message_inner_amqp_properties(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_message_inner_amqp_properties(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): message = ServiceBusMessage("body") @@ -2222,7 +2410,7 @@ def test_message_inner_amqp_properties(self, servicebus_namespace_connection_str message.raw_amqp_message.footer = {b"footer":6} with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: with 
sb_client.get_queue_sender(servicebus_queue.name) as sender: sender.send_messages(message) @@ -2247,21 +2435,32 @@ def test_message_inner_amqp_properties(self, servicebus_namespace_connection_str @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @CachedServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) - def test_queue_send_timeout(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): - def _hack_amqp_sender_run(cls): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_queue_send_timeout(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): + def _hack_amqp_sender_run(self, **kwargs): time.sleep(6) # sleep until timeout - cls.message_handler.work() - cls._waiting_messages = 0 - cls._pending_messages = cls._filter_pending() - if cls._backoff and not cls._waiting_messages: - _logger.info("Client told to backoff - sleeping for %r seconds", cls._backoff) - cls._connection.sleep(cls._backoff) - cls._backoff = 0 - cls._connection.work() + if uamqp_transport: + self.message_handler.work() + self._waiting_messages = 0 + self._pending_messages = self._filter_pending() + if self._backoff and not self._waiting_messages: + _logger.info("Client told to backoff - sleeping for %r seconds", self._backoff) + self._connection.sleep(self._backoff) + self._backoff = 0 + self._connection.work() + else: + try: + # TODO: update for uamqp + self._link.update_pending_deliveries() + self._connection.listen(wait=self._socket_timeout, **kwargs) + except ValueError: + self._shutdown = True + return False return True with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, 
uamqp_transport=uamqp_transport) as sb_client: with sb_client.get_queue_sender(servicebus_queue.name) as sender: # this one doesn't need to reset the method, as it's hacking the method on the instance sender._handler._client_run = types.MethodType(_hack_amqp_sender_run, sender._handler) @@ -2273,94 +2472,159 @@ def _hack_amqp_sender_run(cls): @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @CachedServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) - def test_queue_mgmt_operation_timeout(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): - def hack_mgmt_execute(self, operation, op_type, message, timeout=0): - start_time = self._counter.get_current_ms() - operation_id = str(uuid.uuid4()) - self._responses[operation_id] = None - - time.sleep(6) # sleep until timeout - while not self._responses[operation_id] and not self.mgmt_error: - if timeout > 0: - now = self._counter.get_current_ms() - if (now - start_time) >= timeout: - raise compat.TimeoutException("Failed to receive mgmt response in {}ms".format(timeout)) - self.connection.work() - if self.mgmt_error: - raise self.mgmt_error - response = self._responses.pop(operation_id) - return response - - original_execute_method = uamqp.mgmt_operation.MgmtOperation.execute - # hack the mgmt method on the class, not on an instance, so it needs reset + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_queue_mgmt_operation_timeout(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): + if uamqp_transport: + def hack_mgmt_execute(self, operation, op_type, message, timeout=0): + start_time = self._counter.get_current_ms() + operation_id = str(uuid.uuid4()) + self._responses[operation_id] = None + + time.sleep(6) # sleep until timeout + while not 
self._responses[operation_id] and not self.mgmt_error:
+                    if timeout > 0:
+                        now = self._counter.get_current_ms()
+                        if (now - start_time) >= timeout:
+                            raise uamqp.compat.TimeoutException("Failed to receive mgmt response in {}ms".format(timeout))
+                    self.connection.work()
+                if self.mgmt_error:
+                    raise self.mgmt_error
+                response = self._responses.pop(operation_id)
+                return response
+
+            original_execute_method = uamqp.mgmt_operation.MgmtOperation.execute
+            # hack the mgmt method on the class, not on an instance, so it needs reset
+        else:
+            def hack_mgmt_execute(self, message, operation=None, operation_type=None, timeout=0):
+                start_time = time.time()
+                operation_id = str(uuid.uuid4())
+                self._responses[operation_id] = None
+
+                time.sleep(6)  # sleep until timeout
+                while not self._responses[operation_id] and not self._mgmt_error:
+                    if timeout and timeout > 0:
+                        now = time.time()
+                        if (now - start_time) >= timeout:
+                            raise TimeoutError("Failed to receive mgmt response in {}ms".format(timeout))
+                    self._connection.listen()
+                if self._mgmt_error:
+                    self._responses.pop(operation_id)
+                    raise self._mgmt_error
+
+                response = self._responses.pop(operation_id)
+                return response
+
+            original_execute_method = management_operation.ManagementOperation.execute
+            # hack the mgmt method on the class, not on an instance, so it needs reset
         try:
-            uamqp.mgmt_operation.MgmtOperation.execute = hack_mgmt_execute
+            if uamqp_transport:
+                uamqp.mgmt_operation.MgmtOperation.execute = hack_mgmt_execute
+            else:
+                management_operation.ManagementOperation.execute = hack_mgmt_execute
             with ServiceBusClient.from_connection_string(
-                servicebus_namespace_connection_string, logging_enable=False) as sb_client:
+                servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client:
                 with sb_client.get_queue_sender(servicebus_queue.name) as sender:
                     with pytest.raises(OperationTimeoutError):
                         scheduled_time_utc = utc_now() + timedelta(seconds=30)
                         sender.schedule_messages(ServiceBusMessage("ServiceBusMessage to be scheduled"), scheduled_time_utc, timeout=5)
         finally:
             # must reset the mgmt execute method, otherwise other test cases would use the hacked execute method, leading to timeout error
-            uamqp.mgmt_operation.MgmtOperation.execute = original_execute_method
+            if uamqp_transport:
+                uamqp.mgmt_operation.MgmtOperation.execute = original_execute_method
+            else:
+                management_operation.ManagementOperation.execute = original_execute_method

     @pytest.mark.liveTest
     @pytest.mark.live_test_only
     @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest')
     @CachedServiceBusNamespacePreparer(name_prefix='servicebustest')
     @CachedServiceBusQueuePreparer(name_prefix='servicebustest', lock_duration='PT5S')
-    def test_queue_operation_negative(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs):
-        def _hack_amqp_message_complete(cls):
-            raise RuntimeError()
+    @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids)
+    @ArgPasser()
+    def test_queue_operation_negative(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs):
+        if uamqp_transport:
+            def _hack_amqp_message_complete(cls):
+                raise RuntimeError()
+
+            def _hack_amqp_mgmt_request(cls, message, operation, op_type=None, node=None, callback=None, **kwargs):
+                raise uamqp.errors.AMQPConnectionError()

-        def _hack_amqp_mgmt_request(cls, message, operation, op_type=None, node=None, callback=None, **kwargs):
-            raise uamqp.errors.AMQPConnectionError()
+            def _hack_sb_receiver_settle_message(self, message, settle_operation, dead_letter_reason=None, dead_letter_error_description=None):
+                raise uamqp.errors.AMQPError()
+        else:
+            def _hack_amqp_message_complete(cls, _, settlement):
+                if settlement == 'completed':
+                    raise RuntimeError()

-        def _hack_sb_receiver_settle_message(self, message, settle_operation, dead_letter_reason=None, dead_letter_error_description=None):
-            raise uamqp.errors.AMQPError()
+            def _hack_amqp_mgmt_request(cls, message, operation, op_type=None, node=None, callback=None, **kwargs):
+                raise error.AMQPConnectionError(error.ErrorCondition.ConnectionCloseForced)
+
+            def _hack_sb_receiver_settle_message(self, message, settle_operation, dead_letter_reason=None, dead_letter_error_description=None):
+                raise error.AMQPException(error.ErrorCondition.ClientError)

         with ServiceBusClient.from_connection_string(
-            servicebus_namespace_connection_string, logging_enable=False) as sb_client:
+            servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client:
             sender = sb_client.get_queue_sender(servicebus_queue.name)
             receiver = sb_client.get_queue_receiver(servicebus_queue.name, max_wait_time=5)
-            with sender, receiver:
-                # negative settlement via receiver link
-                sender.send_messages(ServiceBusMessage("body"), timeout=10)
-                message = receiver.receive_messages()[0]
-                message.message.accept = types.MethodType(_hack_amqp_message_complete, message.message)
-                receiver.complete_message(message)  # settle via mgmt link
+            if not uamqp_transport:
+                original_settlement = client.ReceiveClient.settle_messages
+            try:
+                with sender, receiver:
+                    # negative settlement via receiver link
+                    sender.send_messages(ServiceBusMessage("body"), timeout=10)
+                    message = receiver.receive_messages()[0]
+                    if uamqp_transport:
+                        message._message.accept = types.MethodType(_hack_amqp_message_complete, message._message)
+                    else:
+                        client.ReceiveClient.settle_messages = types.MethodType(_hack_amqp_message_complete, receiver._handler)
+                    receiver.complete_message(message)  # settle via mgmt link

-                origin_amqp_client_mgmt_request_method = uamqp.AMQPClient.mgmt_request
-                try:
-                    uamqp.AMQPClient.mgmt_request = _hack_amqp_mgmt_request
-                    with pytest.raises(ServiceBusConnectionError):
-                        receiver.peek_messages()
-                finally:
-                    uamqp.AMQPClient.mgmt_request = origin_amqp_client_mgmt_request_method
+                    if uamqp_transport:
+                        origin_amqp_client_mgmt_request_method = uamqp.AMQPClient.mgmt_request
+                        try:
+                            uamqp.AMQPClient.mgmt_request = _hack_amqp_mgmt_request
+                            with pytest.raises(ServiceBusConnectionError):
+                                receiver.peek_messages()
+                        finally:
+                            uamqp.AMQPClient.mgmt_request = origin_amqp_client_mgmt_request_method
+                    else:
+                        origin_amqp_client_mgmt_request_method = client.AMQPClient.mgmt_request
+                        try:
+                            client.AMQPClient.mgmt_request = _hack_amqp_mgmt_request
+                            with pytest.raises(ServiceBusConnectionError):
+                                receiver.peek_messages()
+                        finally:
+                            client.AMQPClient.mgmt_request = origin_amqp_client_mgmt_request_method

-                sender.send_messages(ServiceBusMessage("body"), timeout=10)
+                    sender.send_messages(ServiceBusMessage("body"), timeout=10)

-                message = receiver.receive_messages()[0]
+                    message = receiver.receive_messages()[0]

-                origin_sb_receiver_settle_message_method = receiver._settle_message
-                receiver._settle_message = types.MethodType(_hack_sb_receiver_settle_message, receiver)
-                with pytest.raises(ServiceBusError):
-                    receiver.complete_message(message)
+                    origin_sb_receiver_settle_message_method = receiver._settle_message
+                    receiver._settle_message = types.MethodType(_hack_sb_receiver_settle_message, receiver)
+                    with pytest.raises(ServiceBusError):
+                        receiver.complete_message(message)

-                receiver._settle_message = origin_sb_receiver_settle_message_method
-                message = receiver.receive_messages(max_wait_time=6)[0]
-                receiver.complete_message(message)
+                    receiver._settle_message = origin_sb_receiver_settle_message_method
+                    message = receiver.receive_messages(max_wait_time=6)[0]
+                    receiver.complete_message(message)
+            finally:
+                if not uamqp_transport:
+                    client.ReceiveClient.settle_messages = original_settlement
+
     @pytest.mark.liveTest
     @pytest.mark.live_test_only
     @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest')
     @CachedServiceBusNamespacePreparer(name_prefix='servicebustest')
     @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True)
-    def test_send_message_no_body(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs):
+    @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids)
+    @ArgPasser()
+    def test_send_message_no_body(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs):
         sb_client = ServiceBusClient.from_connection_string(
-            servicebus_namespace_connection_string)
+            servicebus_namespace_connection_string, uamqp_transport=uamqp_transport)

         with sb_client.get_queue_sender(servicebus_queue.name) as sender:
             sender.send_messages(ServiceBusMessage(body=None))
@@ -2390,11 +2654,13 @@ def test_send_message_alternate_body_types(self, **kwargs):
     @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest')
     @CachedServiceBusNamespacePreparer(name_prefix='servicebustest')
     @CachedServiceBusQueuePreparer(name_prefix='servicebustest')
-    def test_queue_by_servicebus_client_enum_case_sensitivity(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs):
+    @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids)
+    @ArgPasser()
+    def test_queue_by_servicebus_client_enum_case_sensitivity(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs):
         # Note: This test is currently intended to enforce case-sensitivity. If we eventually upgrade to the Fancy Enums being used with new autorest,
         # we may want to tweak this.
         with ServiceBusClient.from_connection_string(
-            servicebus_namespace_connection_string, logging_enable=False) as sb_client:
+            servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client:
             with sb_client.get_queue_receiver(servicebus_queue.name,
                                               receive_mode=ServiceBusReceiveMode.RECEIVE_AND_DELETE.value,
                                               max_wait_time=5) as receiver:
@@ -2414,14 +2680,17 @@ def test_queue_by_servicebus_client_enum_case_sensitivity(self, servicebus_names
                                               max_wait_time=5) as receiver:
                 raise Exception("Should not get here, should be case sensitive.")

+
     @pytest.mark.liveTest
     @pytest.mark.live_test_only
     @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest')
     @CachedServiceBusNamespacePreparer(name_prefix='servicebustest')
     @ServiceBusQueuePreparer(name_prefix='servicebustest')
-    def test_queue_send_dict_messages(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs):
+    @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids)
+    @ArgPasser()
+    def test_queue_send_dict_messages(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs):
         with ServiceBusClient.from_connection_string(
-            servicebus_namespace_connection_string, logging_enable=False) as sb_client:
+            servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client:

             with sb_client.get_queue_sender(servicebus_queue.name) as sender:
@@ -2447,12 +2716,27 @@ def test_queue_send_dict_messages(self, servicebus_namespace_connection_string,
                     received_messages.append(message)
                 assert len(received_messages) == 6

+                batch_message = sender.create_message_batch(max_size_in_bytes=73)
+                for _ in range(2):
+                    try:
+                        batch_message.add_message(message_dict)
+                    except ValueError:
+                        break
+                sender.send_messages(batch_message)
+                received_messages = []
+                with sb_client.get_queue_receiver(servicebus_queue.name, max_wait_time=5) as receiver:
+                    for message in receiver:
+                        received_messages.append(message)
+                    assert len(received_messages) == 1
+
     @pytest.mark.liveTest
     @pytest.mark.live_test_only
     @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest')
     @CachedServiceBusNamespacePreparer(name_prefix='servicebustest')
     @ServiceBusQueuePreparer(name_prefix='servicebustest')
-    def test_queue_send_mapping_messages(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs):
+    @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids)
+    @ArgPasser()
+    def test_queue_send_mapping_messages(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs):
         class MappingMessage(DictMixin):
             def __init__(self, content):
                 self.body = content
@@ -2463,7 +2747,7 @@ def __init__(self):
                 self.message_id = 'foo'

         with ServiceBusClient.from_connection_string(
-            servicebus_namespace_connection_string, logging_enable=False) as sb_client:
+            servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client:

             with sb_client.get_queue_sender(servicebus_queue.name) as sender:
@@ -2499,9 +2783,11 @@ def __init__(self):
     @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest')
     @CachedServiceBusNamespacePreparer(name_prefix='servicebustest')
     @ServiceBusQueuePreparer(name_prefix='servicebustest')
-    def test_queue_send_dict_messages_error_badly_formatted_dicts(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs):
+    @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids)
+    @ArgPasser()
+    def test_queue_send_dict_messages_error_badly_formatted_dicts(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs):
         with ServiceBusClient.from_connection_string(
-            servicebus_namespace_connection_string, logging_enable=False) as sb_client:
+            servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client:

             with sb_client.get_queue_sender(servicebus_queue.name) as sender:
@@ -2527,10 +2813,12 @@ def test_queue_send_dict_messages_error_badly_formatted_dicts(self, servicebus_n
     @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest')
     @CachedServiceBusNamespacePreparer(name_prefix='servicebustest')
     @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True)
-    def test_queue_send_dict_messages_scheduled(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs):
+    @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids)
+    @ArgPasser()
+    def test_queue_send_dict_messages_scheduled(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs):
         with ServiceBusClient.from_connection_string(
-            servicebus_namespace_connection_string, logging_enable=False) as sb_client:
+            servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client:
             content = "Test scheduled message"
             message_id = uuid.uuid4()
             message_id2 = uuid.uuid4()
@@ -2588,10 +2876,12 @@ def test_queue_send_dict_messages_scheduled(self, servicebus_namespace_connectio
     @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest')
     @CachedServiceBusNamespacePreparer(name_prefix='servicebustest')
     @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True)
-    def test_queue_send_dict_messages_scheduled_error_badly_formatted_dicts(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs):
+    @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids)
+    @ArgPasser()
+    def test_queue_send_dict_messages_scheduled_error_badly_formatted_dicts(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs):
         with ServiceBusClient.from_connection_string(
-            servicebus_namespace_connection_string, logging_enable=False) as sb_client:
+            servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client:
             content = "Test scheduled message"
             message_id = uuid.uuid4()
             message_id2 = uuid.uuid4()
@@ -2606,36 +2896,51 @@ def test_queue_send_dict_messages_scheduled_error_badly_formatted_dicts(self, se
             with pytest.raises(TypeError):
                 sender.schedule_messages(list_message_dicts, scheduled_enqueue_time)

+
     @pytest.mark.liveTest
     @pytest.mark.live_test_only
     @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest')
     @CachedServiceBusNamespacePreparer(name_prefix='servicebustest')
     @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True)
-    def test_queue_receive_iterator_resume_after_link_detach(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs):
+    @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids)
+    @ArgPasser()
+    def test_queue_receive_iterator_resume_after_link_detach(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs):

-        def hack_iter_next_mock_error(self):
+        def hack_iter_next_mock_error(self, wait_time=None):
             try:
                 self._receive_context.set()
                 self._open()
                 # when trying to receive the second message (execution_times is 1), raising LinkDetach error to mock 10 mins idle timeout
                 if self.execution_times == 1:
-                    from uamqp.errors import LinkDetach
-                    from uamqp.constants import ErrorCodes
+                    # TODO: update uamqp errors to pyamqp
+                    if uamqp_transport:
+                        from uamqp.errors import LinkDetach
+                        from uamqp.constants import ErrorCodes
+                        error = LinkDetach
+                        error_condition = ErrorCodes
+                    else:
+                        from azure.servicebus._pyamqp.error import ErrorCondition, AMQPLinkError
+                        error = AMQPLinkError
+                        error_condition = ErrorCondition
+
                     self.execution_times += 1
                     self.error_raised = True
-                    raise LinkDetach(ErrorCodes.LinkDetachForced)
+                    raise error(error_condition.LinkDetachForced)
                 else:
                     self.execution_times += 1
                 if not self._message_iter:
-                    self._message_iter = self._handler.receive_messages_iter()
-                uamqp_message = next(self._message_iter)
-                message = self._build_message(uamqp_message)
+                    if uamqp_transport:
+                        self._message_iter = self._handler.receive_messages_iter()
+                    else:
+                        self._message_iter = self._handler.receive_messages_iter(timeout=wait_time)
+                amqp_message = next(self._message_iter)
+                message = self._build_received_message(amqp_message)
                 return message
             finally:
                 self._receive_context.clear()

         with ServiceBusClient.from_connection_string(
-            servicebus_namespace_connection_string, logging_enable=False) as sb_client:
+            servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client:
             with sb_client.get_queue_sender(servicebus_queue.name) as sender:
                 sender.send_messages(
                     [ServiceBusMessage("test1"), ServiceBusMessage("test2"), ServiceBusMessage("test3")]
@@ -2652,15 +2957,18 @@ def hack_iter_next_mock_error(self):
             assert receiver.error_raised
             assert receiver.execution_times >= 4  # at least 1 failure and 3 successful receiving iterator

+
     @pytest.mark.liveTest
     @pytest.mark.live_test_only
     @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest')
     @CachedServiceBusNamespacePreparer(name_prefix='servicebustest')
     @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True)
-    def test_queue_send_amqp_annotated_message(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs):
+    @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids)
+    @ArgPasser()
+    def test_queue_send_amqp_annotated_message(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs):
         with ServiceBusClient.from_connection_string(
-            servicebus_namespace_connection_string, logging_enable=False) as sb_client:
+            servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client:
             sequence_body = [b'message', 123.456, True]
             footer = {'footer_key': 'footer_value'}
             prop = {"subject": "sequence"}
@@ -2703,8 +3011,12 @@ def test_queue_send_amqp_annotated_message(self, servicebus_namespace_connection
             dict_message = {"body": content}
             sb_message = ServiceBusMessage(body=content)
             message_with_ttl = AmqpAnnotatedMessage(data_body=data_body, header=AmqpMessageHeader(time_to_live=60000))
-            uamqp_with_ttl = message_with_ttl._to_outgoing_amqp_message()
-            assert uamqp_with_ttl.properties.absolute_expiry_time == uamqp_with_ttl.properties.creation_time + uamqp_with_ttl.header.time_to_live
+            if uamqp_transport:
+                amqp_transport = UamqpTransport
+            else:
+                amqp_transport = PyamqpTransport
+            amqp_with_ttl = amqp_transport.to_outgoing_amqp_message(message_with_ttl)
+            assert amqp_with_ttl.properties.absolute_expiry_time == amqp_with_ttl.properties.creation_time + amqp_with_ttl.header.ttl

             recv_data_msg = recv_sequence_msg = recv_value_msg = normal_msg = 0
             with sb_client.get_queue_receiver(servicebus_queue.name, max_wait_time=10) as receiver:
@@ -2761,9 +3073,11 @@ def test_queue_send_amqp_annotated_message(self, servicebus_namespace_connection
     @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest')
     @CachedServiceBusNamespacePreparer(name_prefix='servicebustest')
     @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True)
-    def test_state_scheduled(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs):
+    @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids)
+    @ArgPasser()
+    def test_state_scheduled(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs):
         with ServiceBusClient.from_connection_string(
-            servicebus_namespace_connection_string) as sb_client:
+            servicebus_namespace_connection_string, uamqp_transport=uamqp_transport) as sb_client:

             sender = sb_client.get_queue_sender(servicebus_queue.name)
             for i in range(10):
@@ -2782,9 +3096,11 @@ def test_state_scheduled(self, servicebus_namespace_connection_string, servicebu
     @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest')
     @CachedServiceBusNamespacePreparer(name_prefix='servicebustest')
     @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True)
-    def test_state_deferred(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs):
+    @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids)
+    @ArgPasser()
+    def test_state_deferred(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs):
         with ServiceBusClient.from_connection_string(
-            servicebus_namespace_connection_string) as sb_client:
+            servicebus_namespace_connection_string, uamqp_transport=uamqp_transport) as sb_client:

             sender = sb_client.get_queue_sender(servicebus_queue.name)
             for i in range(10):
diff --git a/sdk/servicebus/azure-servicebus/tests/test_sb_client.py b/sdk/servicebus/azure-servicebus/tests/test_sb_client.py
index f55b36821b2c6..f3ca826b4a176 100644
--- a/sdk/servicebus/azure-servicebus/tests/test_sb_client.py
+++ b/sdk/servicebus/azure-servicebus/tests/test_sb_client.py
@@ -10,9 +10,22 @@
 import pytest
 import time
 from datetime import datetime, timedelta
+import hmac
+import hashlib
+import base64
+try:
+    from urllib.parse import quote as url_parse_quote
+except ImportError:
+    from urllib import pathname2url as url_parse_quote
+
+try:
+    from azure.servicebus._transport._uamqp_transport import UamqpTransport
+except ImportError:
+    pass
+from azure.servicebus._transport._pyamqp_transport import PyamqpTransport

 from azure.common import AzureHttpError, AzureConflictHttpError
-from azure.core.credentials import AzureSasCredential, AzureNamedKeyCredential
+from azure.core.credentials import AzureSasCredential, AzureNamedKeyCredential, AccessToken
 from azure.core.pipeline.policies import RetryMode
 from azure.mgmt.servicebus.models import AccessRights
 from azure.servicebus import ServiceBusClient, ServiceBusSender, ServiceBusReceiver
@@ -21,9 +34,10 @@
 from azure.servicebus.exceptions import (
     ServiceBusError,
     ServiceBusAuthenticationError,
-    ServiceBusAuthorizationError
+    ServiceBusAuthorizationError,
+    ServiceBusConnectionError
 )
-from devtools_testutils import AzureMgmtTestCase
+from devtools_testutils import AzureMgmtRecordedTestCase
 from servicebus_preparer import (
     CachedServiceBusNamespacePreparer,
     ServiceBusTopicPreparer,
@@ -36,19 +50,25 @@
     CachedServiceBusResourceGroupPreparer,
     SERVICEBUS_ENDPOINT_SUFFIX
 )
+from utilities import uamqp_transport as get_uamqp_transport, ArgPasser
+uamqp_transport_params, uamqp_transport_ids = get_uamqp_transport()

-class ServiceBusClientTests(AzureMgmtTestCase):
+class TestServiceBusClient(AzureMgmtRecordedTestCase):

     @pytest.mark.liveTest
     @pytest.mark.live_test_only
     @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest')
     @CachedServiceBusNamespacePreparer(name_prefix='servicebustest')
     @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True)
-    def test_sb_client_bad_credentials(self, servicebus_namespace, servicebus_queue, **kwargs):
+    @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids)
+    @ArgPasser()
+    def test_sb_client_bad_credentials(self, uamqp_transport, *, servicebus_namespace=None, servicebus_queue=None, **kwargs):
         client = ServiceBusClient(
             fully_qualified_namespace=servicebus_namespace.name + f"{SERVICEBUS_ENDPOINT_SUFFIX}",
             credential=ServiceBusSharedKeyCredential('invalid', 'invalid'),
-            logging_enable=False)
+            logging_enable=False,
+            uamqp_transport=uamqp_transport
+        )
         with client:
             with pytest.raises(ServiceBusAuthenticationError):
                 with client.get_queue_sender(servicebus_queue.name) as sender:
@@ -56,12 +76,14 @@ def test_sb_client_bad_credentials(self, servicebus_namespace, servicebus_queue,

     @pytest.mark.liveTest
     @pytest.mark.live_test_only
-    def test_sb_client_bad_namespace(self, **kwargs):
-
+    @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids)
+    def test_sb_client_bad_namespace(self, uamqp_transport, **kwargs):
         client = ServiceBusClient(
             fully_qualified_namespace=f"invalid{SERVICEBUS_ENDPOINT_SUFFIX}",
             credential=ServiceBusSharedKeyCredential('invalid', 'invalid'),
-            logging_enable=False)
+            logging_enable=False,
+            uamqp_transport=uamqp_transport
+        )
         with client:
             with pytest.raises(ServiceBusError):
                 with client.get_queue_sender('invalidqueue') as sender:
@@ -71,9 +93,10 @@ def test_sb_client_bad_namespace(self, **kwargs):
     @pytest.mark.live_test_only
     @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest')
     @CachedServiceBusNamespacePreparer(name_prefix='servicebustest')
-    def test_sb_client_bad_entity(self, servicebus_namespace_connection_string, **kwargs):
-
-        client = ServiceBusClient.from_connection_string(servicebus_namespace_connection_string)
+    @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids)
+    @ArgPasser()
+    def test_sb_client_bad_entity(self, uamqp_transport, *, servicebus_namespace_connection_string=None, **kwargs):
+        client = ServiceBusClient.from_connection_string(servicebus_namespace_connection_string, uamqp_transport=uamqp_transport)

         with client:
             with pytest.raises(ServiceBusAuthenticationError):
@@ -82,7 +105,7 @@ def test_sb_client_bad_entity(self, servicebus_namespace_connection_string, **kw
         fake_str = f"Endpoint=sb://mock{SERVICEBUS_ENDPOINT_SUFFIX}/;" \
                    f"SharedAccessKeyName=mock;SharedAccessKey=mock;EntityPath=mockentity"
-        fake_client = ServiceBusClient.from_connection_string(fake_str)
+        fake_client = ServiceBusClient.from_connection_string(fake_str, uamqp_transport=uamqp_transport)

         with pytest.raises(ValueError):
             fake_client.get_queue_sender('queue')
@@ -103,7 +126,7 @@ def test_sb_client_bad_entity(self, servicebus_namespace_connection_string, **kw
         fake_str = f"Endpoint=sb://mock{SERVICEBUS_ENDPOINT_SUFFIX}/;" \
                    f"SharedAccessKeyName=mock;SharedAccessKey=mock"
-        fake_client = ServiceBusClient.from_connection_string(fake_str)
+        fake_client = ServiceBusClient.from_connection_string(fake_str, uamqp_transport=uamqp_transport)
         fake_client.get_queue_sender('queue')
         fake_client.get_queue_receiver('queue')
         fake_client.get_topic_sender('topic')
@@ -115,8 +138,10 @@ def test_sb_client_bad_entity(self, servicebus_namespace_connection_string, **kw
     @CachedServiceBusNamespacePreparer(name_prefix='servicebustest')
     @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True)
     @ServiceBusNamespaceAuthorizationRulePreparer(name_prefix='servicebustest', access_rights=[AccessRights.listen])
-    def test_sb_client_readonly_credentials(self, servicebus_authorization_rule_connection_string, servicebus_queue, **kwargs):
-        client = ServiceBusClient.from_connection_string(servicebus_authorization_rule_connection_string)
+    @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids)
+    @ArgPasser()
+    def test_sb_client_readonly_credentials(self, uamqp_transport, *, servicebus_authorization_rule_connection_string=None, servicebus_queue=None, **kwargs):
+        client = ServiceBusClient.from_connection_string(servicebus_authorization_rule_connection_string, uamqp_transport=uamqp_transport)

         with client:
             with client.get_queue_receiver(servicebus_queue.name) as receiver:
@@ -132,8 +157,10 @@ def test_sb_client_readonly_credentials(self, servicebus_authorization_rule_conn
     @CachedServiceBusNamespacePreparer(name_prefix='servicebustest')
     @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True)
     @ServiceBusNamespaceAuthorizationRulePreparer(name_prefix='servicebustest', access_rights=[AccessRights.send])
-    def test_sb_client_writeonly_credentials(self, servicebus_authorization_rule_connection_string, servicebus_queue, **kwargs):
-        client = ServiceBusClient.from_connection_string(servicebus_authorization_rule_connection_string)
+    @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids)
+    @ArgPasser()
+    def test_sb_client_writeonly_credentials(self, uamqp_transport, *, servicebus_authorization_rule_connection_string=None, servicebus_queue=None, **kwargs):
+        client = ServiceBusClient.from_connection_string(servicebus_authorization_rule_connection_string, uamqp_transport=uamqp_transport)

         with client:
             with pytest.raises(ServiceBusError):
@@ -154,9 +181,14 @@ def test_sb_client_writeonly_credentials(self, servicebus_authorization_rule_con
     @ServiceBusQueuePreparer(name_prefix='servicebustest_qone', parameter_name='wrong_queue', dead_lettering_on_message_expiration=True)
     @ServiceBusQueuePreparer(name_prefix='servicebustest_qtwo', dead_lettering_on_message_expiration=True)
     @ServiceBusQueueAuthorizationRulePreparer(name_prefix='servicebustest_qtwo')
-    def test_sb_client_incorrect_queue_conn_str(self, servicebus_queue_authorization_rule_connection_string, servicebus_queue, wrong_queue, **kwargs):
-
-        client = ServiceBusClient.from_connection_string(servicebus_queue_authorization_rule_connection_string)
+    @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids)
+    @ArgPasser()
+    def test_sb_client_incorrect_queue_conn_str(self, uamqp_transport, *, servicebus_queue_authorization_rule_connection_string=None, servicebus_queue=None, wrong_queue=None, **kwargs):
+        if uamqp_transport:
+            amqp_transport = UamqpTransport
+        else:
+            amqp_transport = PyamqpTransport
+        client = ServiceBusClient.from_connection_string(servicebus_queue_authorization_rule_connection_string, uamqp_transport=uamqp_transport)
         with client:
             # Validate that the wrong sender/receiver queues with the right credentials fail.
             with pytest.raises(ValueError):
@@ -179,6 +211,7 @@ def test_sb_client_incorrect_queue_conn_str(self, servicebus_queue_authorization
                 with ServiceBusSender._from_connection_string(
                     servicebus_queue_authorization_rule_connection_string,
                     queue_name=wrong_queue.name,
+                    amqp_transport=amqp_transport
                 ) as sender:
                     sender.send_messages(ServiceBusMessage("test"))
@@ -186,18 +219,21 @@ def test_sb_client_incorrect_queue_conn_str(self, servicebus_queue_authorization
                 with ServiceBusReceiver._from_connection_string(
                     servicebus_queue_authorization_rule_connection_string,
                     queue_name=wrong_queue.name,
+                    amqp_transport=amqp_transport
                 ) as receiver:
                     messages = receiver.receive_messages(max_message_count=1, max_wait_time=1)

             with ServiceBusSender._from_connection_string(
                 servicebus_queue_authorization_rule_connection_string,
                 queue_name=servicebus_queue.name,
+                amqp_transport=amqp_transport
             ) as sender:
                 sender.send_messages(ServiceBusMessage("test"))

             with ServiceBusReceiver._from_connection_string(
                 servicebus_queue_authorization_rule_connection_string,
                 queue_name=servicebus_queue.name,
+                amqp_transport=amqp_transport
             ) as receiver:
                 messages = receiver.receive_messages(max_message_count=1, max_wait_time=1)
@@ -208,8 +244,10 @@ def test_sb_client_incorrect_queue_conn_str(self, servicebus_queue_authorization
     @CachedServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True)
     @CachedServiceBusTopicPreparer(name_prefix='servicebustest')
     @CachedServiceBusSubscriptionPreparer(name_prefix='servicebustest')
-    def test_sb_client_close_spawned_handlers(self, servicebus_namespace_connection_string, servicebus_queue, servicebus_topic, servicebus_subscription, **kwargs):
-        client = ServiceBusClient.from_connection_string(servicebus_namespace_connection_string)
+    @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids)
+    @ArgPasser()
+    def test_sb_client_close_spawned_handlers(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, servicebus_topic=None, servicebus_subscription=None, **kwargs):
+        client = ServiceBusClient.from_connection_string(servicebus_namespace_connection_string, uamqp_transport=uamqp_transport)

         client.close()
@@ -295,7 +333,11 @@ def test_sb_client_close_spawned_handlers(self, servicebus_namespace_connection_
     @CachedServiceBusResourceGroupPreparer()
     @CachedServiceBusNamespacePreparer(name_prefix='servicebustest')
     @CachedServiceBusQueuePreparer(name_prefix='servicebustest')
+    @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids)
+    @ArgPasser()
     def test_client_sas_credential(self,
+                                   uamqp_transport,
+                                   *,
                                    servicebus_queue,
                                    servicebus_namespace,
                                    servicebus_namespace_key_name,
@@ -309,9 +351,9 @@ def test_client_sas_credential(self,
         token = credential.get_token(auth_uri).token

         # Finally let's do it with SAS token + conn str
-        token_conn_str = "Endpoint=sb://{}/;SharedAccessSignature={};".format(hostname, token.decode())
+        token_conn_str = "Endpoint=sb://{}/;SharedAccessSignature={};".format(hostname, token)

-        client = ServiceBusClient.from_connection_string(token_conn_str)
+        client = ServiceBusClient.from_connection_string(token_conn_str, uamqp_transport=uamqp_transport)
         with client:
             assert len(client._handlers) == 0
             with client.get_queue_sender(servicebus_queue.name) as sender:
@@ -321,29 +363,72 @@ def test_client_sas_credential(self,
         #
         #token_conn_str_without_se = token_conn_str.split('se=')[0] + token_conn_str.split('se=')[1].split('&')[1]
         #
-        #client = ServiceBusClient.from_connection_string(token_conn_str_without_se)
+        #client = ServiceBusClient.from_connection_string(token_conn_str_without_se, uamqp_transport=uamqp_transport)
         #with client:
         #    assert len(client._handlers) == 0
         #    with client.get_queue_sender(servicebus_queue.name) as sender:
         #        sender.send_messages(ServiceBusMessage("foo"))

+        def generate_sas_token(uri, sas_name, sas_value, token_ttl):
+            """Performs the signing and encoding needed to generate a sas token from a sas key."""
+            sas = sas_value.encode('utf-8')
+            expiry = str(int(time.time() + token_ttl))
+            string_to_sign = (uri + '\n' + expiry).encode('utf-8')
+            signed_hmac_sha256 = hmac.HMAC(sas, string_to_sign, hashlib.sha256)
+            signature = url_parse_quote(base64.b64encode(signed_hmac_sha256.digest()))
+            return 'SharedAccessSignature sr={}&sig={}&se={}&skn={}'.format(uri, signature, expiry, sas_name)
+
+        class CustomizedSASCredential(object):
+            def __init__(self, token, expiry):
+                """
+                :param str token: The token string
+                :param float expiry: The epoch timestamp
+                """
+                self.token = token
+                self.expiry = expiry
+                self.token_type = b"servicebus.windows.net:sastoken"
+
+            def get_token(self, *scopes, **kwargs):
+                """
+                This method is automatically called when token is about to expire.
+                """
+                return AccessToken(self.token, self.expiry)
+
+        token_ttl = 5  # seconds
+        sas_token = generate_sas_token(
+            auth_uri, servicebus_namespace_key_name, servicebus_namespace_primary_key, token_ttl
+        )
+        credential=CustomizedSASCredential(sas_token, time.time() + token_ttl)
+
+        with ServiceBusClient(hostname, credential, uamqp_transport=uamqp_transport) as client:
+            sender = client.get_queue_sender(queue_name=servicebus_queue.name)
+            time.sleep(10)
+            with pytest.raises(ServiceBusAuthenticationError):
+                with sender:
+                    message = ServiceBusMessage("Single Message")
+                    sender.send_messages(message)
+
     @pytest.mark.liveTest
     @pytest.mark.live_test_only
     @CachedServiceBusResourceGroupPreparer()
     @CachedServiceBusNamespacePreparer(name_prefix='servicebustest')
     @CachedServiceBusQueuePreparer(name_prefix='servicebustest')
+    @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids)
+    @ArgPasser()
     def test_client_credential(self,
-                               servicebus_queue,
-                               servicebus_namespace,
-                               servicebus_namespace_key_name,
-                               servicebus_namespace_primary_key,
-                               servicebus_namespace_connection_string,
+                               uamqp_transport,
+                               *,
+                               servicebus_queue=None,
+                               servicebus_namespace=None,
+                               servicebus_namespace_key_name=None,
+                               servicebus_namespace_primary_key=None,
+                               servicebus_namespace_connection_string=None,
                                **kwargs):
         # This should "just work" to validate known-good.
         credential = ServiceBusSharedKeyCredential(servicebus_namespace_key_name, servicebus_namespace_primary_key)
         hostname = f"{servicebus_namespace.name}{SERVICEBUS_ENDPOINT_SUFFIX}"

-        client = ServiceBusClient(hostname, credential)
+        client = ServiceBusClient(hostname, credential, uamqp_transport=uamqp_transport)
         with client:
             assert len(client._handlers) == 0
             with client.get_queue_sender(servicebus_queue.name) as sender:
@@ -351,7 +436,7 @@ def test_client_credential(self,

         hostname = f"sb://{servicebus_namespace.name}{SERVICEBUS_ENDPOINT_SUFFIX}"

-        client = ServiceBusClient(hostname, credential)
+        client = ServiceBusClient(hostname, credential, uamqp_transport=uamqp_transport)
         with client:
             assert len(client._handlers) == 0
             with client.get_queue_sender(servicebus_queue.name) as sender:
@@ -359,7 +444,7 @@ def test_client_credential(self,

         hostname = f"https://{servicebus_namespace.name}{SERVICEBUS_ENDPOINT_SUFFIX}"

-        client = ServiceBusClient(hostname, credential)
+        client = ServiceBusClient(hostname, credential, uamqp_transport=uamqp_transport)
         with client:
             assert len(client._handlers) == 0
             with client.get_queue_sender(servicebus_queue.name) as sender:
@@ -370,23 +455,27 @@ def test_client_credential(self,
     @CachedServiceBusResourceGroupPreparer()
     @CachedServiceBusNamespacePreparer(name_prefix='servicebustest')
     @CachedServiceBusQueuePreparer(name_prefix='servicebustest')
-    def test_client_azure_named_key_credential(self,
-                                               servicebus_queue,
-                                               servicebus_namespace,
-                                               servicebus_namespace_key_name,
-                                               servicebus_namespace_primary_key,
-                                               servicebus_namespace_connection_string,
+    @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids)
+    @ArgPasser()
+    def test_client_azure_sas_credential(self,
+                                         uamqp_transport,
+                                         *,
+                                         servicebus_queue=None,
+                                         servicebus_namespace=None,
+                                         servicebus_namespace_key_name=None,
+                                         servicebus_namespace_primary_key=None,
+                                         servicebus_namespace_connection_string=None,
                                          **kwargs):
         # This should "just work" to validate known-good.
         credential = ServiceBusSharedKeyCredential(servicebus_namespace_key_name, servicebus_namespace_primary_key)
         hostname = f"{servicebus_namespace.name}{SERVICEBUS_ENDPOINT_SUFFIX}"
         auth_uri = "sb://{}/{}".format(hostname, servicebus_queue.name)
-        token = credential.get_token(auth_uri).token.decode()
+        token = credential.get_token(auth_uri).token

         # Finally let's do it with AzureSasCredential
         credential = AzureSasCredential(token)

-        client = ServiceBusClient(hostname, credential)
+        client = ServiceBusClient(hostname, credential, uamqp_transport=uamqp_transport)
         with client:
             assert len(client._handlers) == 0
@@ -395,17 +484,21 @@ def test_client_azure_named_key_credential(self,
     @CachedServiceBusResourceGroupPreparer()
     @CachedServiceBusNamespacePreparer(name_prefix='servicebustest')
     @CachedServiceBusQueuePreparer(name_prefix='servicebustest')
-    def test_client_azure_named_key_credential(self,
-                                               servicebus_queue,
-                                               servicebus_namespace,
-                                               servicebus_namespace_key_name,
-                                               servicebus_namespace_primary_key,
-                                               servicebus_namespace_connection_string,
+    @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids)
+    @ArgPasser()
+    def test_azure_named_key_credential(self,
+                                        uamqp_transport,
+                                        *,
+                                        servicebus_queue=None,
+                                        servicebus_namespace=None,
+                                        servicebus_namespace_key_name=None,
+                                        servicebus_namespace_primary_key=None,
+                                        servicebus_namespace_connection_string=None,
                                         **kwargs):
         hostname = f"{servicebus_namespace.name}{SERVICEBUS_ENDPOINT_SUFFIX}"
         credential = AzureNamedKeyCredential(servicebus_namespace_key_name, servicebus_namespace_primary_key)

-        client = ServiceBusClient(hostname, credential)
+        client = ServiceBusClient(hostname, credential, uamqp_transport=uamqp_transport)
         with
client: with client.get_queue_sender(servicebus_queue.name) as sender: sender.send_messages(ServiceBusMessage("foo")) @@ -422,11 +515,14 @@ def test_client_azure_named_key_credential(self, with client.get_queue_sender(servicebus_queue.name) as sender: sender.send_messages(ServiceBusMessage("foo")) - def test_backoff_fixed_retry(self): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + def test_backoff_fixed_retry(self, uamqp_transport): + client = ServiceBusClient( 'fake.host.com', 'fake_eh', - retry_mode='fixed' + retry_mode='fixed', + uamqp_transport=uamqp_transport ) # queue sender sender = client.get_queue_sender('fake_name') @@ -466,7 +562,8 @@ def test_backoff_fixed_retry(self): client = ServiceBusClient( 'fake.host.com', 'fake_eh', - retry_mode=RetryMode.Fixed + retry_mode=RetryMode.Fixed, + uamqp_transport=uamqp_transport ) # queue sender sender = client.get_queue_sender('fake_name') @@ -479,79 +576,87 @@ def test_backoff_fixed_retry(self): # check that fixed is less than 'exp' assert sleep_time_fixed < backoff * (2 ** 1) - def test_custom_client_id_queue_sender(self, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + def test_custom_client_id_queue_sender(self, uamqp_transport, **kwargs): servicebus_connection_str = f'Endpoint=sb://resourcename{SERVICEBUS_ENDPOINT_SUFFIX}/;SharedAccessSignature=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=;' queue_name = "queue_name" custom_id = "my_custom_id" - servicebus_client = ServiceBusClient.from_connection_string(conn_str=servicebus_connection_str) + servicebus_client = ServiceBusClient.from_connection_string(conn_str=servicebus_connection_str, uamqp_transport=uamqp_transport) with servicebus_client: queue_sender = servicebus_client.get_queue_sender(queue_name=queue_name, client_identifier=custom_id) assert queue_sender.client_identifier is not None assert queue_sender.client_identifier == custom_id - def 
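The backoff test above asserts `sleep_time_fixed < backoff * (2 ** 1)`: `retry_mode='fixed'` sleeps the same backoff factor on every attempt, while the default exponential mode doubles it per retry. A minimal sketch of that relationship (the client computes this internally; the function name and the `0.8` factor here are illustrative):

```python
def sleep_time(retry_mode, backoff_factor, retried_times):
    # Fixed mode waits the same interval on every retry;
    # exponential mode waits backoff_factor * 2 ** retried_times.
    if retry_mode == 'fixed':
        return backoff_factor
    return backoff_factor * (2 ** retried_times)


fixed = sleep_time('fixed', 0.8, retried_times=1)
exp = sleep_time('exponential', 0.8, retried_times=1)
```

After the first retry the fixed delay is already smaller than the exponential one, which is exactly the inequality the test checks.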
test_default_client_id_queue_sender(self, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + def test_default_client_id_queue_sender(self, uamqp_transport, **kwargs): servicebus_connection_str = f'Endpoint=sb://resourcename{SERVICEBUS_ENDPOINT_SUFFIX}/;SharedAccessSignature=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=;' queue_name = "queue_name" - servicebus_client = ServiceBusClient.from_connection_string(conn_str=servicebus_connection_str) + servicebus_client = ServiceBusClient.from_connection_string(conn_str=servicebus_connection_str, uamqp_transport=uamqp_transport) with servicebus_client: queue_sender = servicebus_client.get_queue_sender(queue_name=queue_name) assert queue_sender.client_identifier is not None assert "SBSender" in queue_sender.client_identifier - def test_custom_client_id_queue_receiver(self, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + def test_custom_client_id_queue_receiver(self, uamqp_transport, **kwargs): servicebus_connection_str = f'Endpoint=sb://resourcename{SERVICEBUS_ENDPOINT_SUFFIX}/;SharedAccessSignature=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=;' queue_name = "queue_name" custom_id = "my_custom_id" - servicebus_client = ServiceBusClient.from_connection_string(conn_str=servicebus_connection_str) + servicebus_client = ServiceBusClient.from_connection_string(conn_str=servicebus_connection_str, uamqp_transport=uamqp_transport) with servicebus_client: queue_receiver = servicebus_client.get_queue_receiver(queue_name=queue_name, client_identifier=custom_id) assert queue_receiver.client_identifier is not None assert queue_receiver.client_identifier == custom_id - def test_default_client_id_queue_receiver(self, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + def test_default_client_id_queue_receiver(self, uamqp_transport, **kwargs): servicebus_connection_str = 
f'Endpoint=sb://resourcename{SERVICEBUS_ENDPOINT_SUFFIX}/;SharedAccessSignature=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=;' queue_name = "queue_name" - servicebus_client = ServiceBusClient.from_connection_string(conn_str=servicebus_connection_str) + servicebus_client = ServiceBusClient.from_connection_string(conn_str=servicebus_connection_str, uamqp_transport=uamqp_transport) with servicebus_client: queue_receiver = servicebus_client.get_queue_receiver(queue_name=queue_name) assert queue_receiver.client_identifier is not None assert "SBReceiver" in queue_receiver.client_identifier - def test_custom_client_id_topic_sender(self, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + def test_custom_client_id_topic_sender(self, uamqp_transport, **kwargs): servicebus_connection_str = f'Endpoint=sb://resourcename{SERVICEBUS_ENDPOINT_SUFFIX}/;SharedAccessSignature=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=;' custom_id = "my_custom_id" topic_name = "topic_name" - servicebus_client = ServiceBusClient.from_connection_string(conn_str=servicebus_connection_str) + servicebus_client = ServiceBusClient.from_connection_string(conn_str=servicebus_connection_str, uamqp_transport=uamqp_transport) with servicebus_client: topic_sender = servicebus_client.get_topic_sender(topic_name=topic_name, client_identifier=custom_id) assert topic_sender.client_identifier is not None assert topic_sender.client_identifier == custom_id - def test_default_client_id_topic_sender(self, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + def test_default_client_id_topic_sender(self, uamqp_transport, **kwargs): servicebus_connection_str = f'Endpoint=sb://resourcename{SERVICEBUS_ENDPOINT_SUFFIX}/;SharedAccessSignature=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=;' topic_name = "topic_name" - servicebus_client = ServiceBusClient.from_connection_string(conn_str=servicebus_connection_str) + 
servicebus_client = ServiceBusClient.from_connection_string(conn_str=servicebus_connection_str, uamqp_transport=uamqp_transport) with servicebus_client: topic_sender = servicebus_client.get_topic_sender(topic_name=topic_name) assert topic_sender.client_identifier is not None assert "SBSender" in topic_sender.client_identifier - def test_default_client_id_subscription_receiver(self, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + def test_default_client_id_subscription_receiver(self, uamqp_transport, **kwargs): servicebus_connection_str = f'Endpoint=sb://resourcename{SERVICEBUS_ENDPOINT_SUFFIX}/;SharedAccessSignature=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=;' topic_name = "topic_name" sub_name = "sub_name" - servicebus_client = ServiceBusClient.from_connection_string(conn_str=servicebus_connection_str) + servicebus_client = ServiceBusClient.from_connection_string(conn_str=servicebus_connection_str, uamqp_transport=uamqp_transport) with servicebus_client: subscription_receiver = servicebus_client.get_subscription_receiver(topic_name, sub_name) assert subscription_receiver.client_identifier is not None assert "SBReceiver" in subscription_receiver.client_identifier - def test_custom_client_id_subscription_receiver(self, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + def test_custom_client_id_subscription_receiver(self, uamqp_transport, **kwargs): servicebus_connection_str = f'Endpoint=sb://resourcename{SERVICEBUS_ENDPOINT_SUFFIX}/;SharedAccessSignature=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=;' custom_id = "my_custom_id" topic_name = "topic_name" sub_name = "sub_name" - servicebus_client = ServiceBusClient.from_connection_string(conn_str=servicebus_connection_str) + servicebus_client = ServiceBusClient.from_connection_string(conn_str=servicebus_connection_str, uamqp_transport=uamqp_transport) with servicebus_client: subscription_receiver = 
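The client-identifier tests in this hunk check two behaviors: a caller-supplied `client_identifier` is used verbatim, and otherwise a default containing `SBSender` or `SBReceiver` is generated. A sketch of that fallback pattern (the actual default format is an implementation detail of the SDK; this shape is an assumption):

```python
import uuid


def make_client_identifier(kind, custom_id=None):
    # Honor a caller-supplied identifier; otherwise generate a
    # recognizable default such as "SBSender-<uuid>".
    if custom_id is not None:
        return custom_id
    return '{}-{}'.format(kind, uuid.uuid4())


custom = make_client_identifier('SBSender', 'my_custom_id')
default = make_client_identifier('SBReceiver')
```

This is why the tests can assert both `client_identifier == custom_id` and `"SBReceiver" in client_identifier` without knowing the random suffix.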
servicebus_client.get_subscription_receiver(topic_name, sub_name, client_identifier=custom_id) assert subscription_receiver.client_identifier is not None @@ -562,18 +667,50 @@ def test_custom_client_id_subscription_receiver(self, **kwargs): @CachedServiceBusResourceGroupPreparer() @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @CachedServiceBusQueuePreparer(name_prefix='servicebustest') - def test_connection_verify_exception(self, - servicebus_queue, - servicebus_namespace, - servicebus_namespace_key_name, - servicebus_namespace_primary_key, - servicebus_namespace_connection_string, + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_custom_endpoint_connection_verify_exception(self, + uamqp_transport, + *, + servicebus_queue=None, + servicebus_namespace=None, + servicebus_namespace_key_name=None, + servicebus_namespace_primary_key=None, + servicebus_namespace_connection_string=None, **kwargs): hostname = f"{servicebus_namespace.name}{SERVICEBUS_ENDPOINT_SUFFIX}" credential = AzureNamedKeyCredential(servicebus_namespace_key_name, servicebus_namespace_primary_key) - client = ServiceBusClient(hostname, credential, connection_verify="cacert.pem") + client = ServiceBusClient(hostname, credential, connection_verify="cacert.pem", uamqp_transport=uamqp_transport) with client: with pytest.raises(ServiceBusError): with client.get_queue_sender(servicebus_queue.name) as sender: sender.send_messages(ServiceBusMessage("foo")) + + # Skipping on OSX uamqp - it's raising an Authentication/TimeoutError + if not uamqp_transport or not sys.platform.startswith('darwin'): + fake_addr = "fakeaddress.com:1111" + client = ServiceBusClient( + hostname, + credential, + custom_endpoint_address=fake_addr, + retry_total=0, + uamqp_transport=uamqp_transport + ) + with client: + with pytest.raises(ServiceBusConnectionError): + with client.get_queue_sender(servicebus_queue.name) as sender: + 
sender.send_messages(ServiceBusMessage("foo"))
+
+            client = ServiceBusClient(
+                hostname,
+                credential,
+                custom_endpoint_address=fake_addr,
+                connection_verify="cacert.pem",
+                retry_total=0,
+                uamqp_transport=uamqp_transport,
+            )
+            with client:
+                with pytest.raises(ServiceBusError):
+                    with client.get_queue_sender(servicebus_queue.name) as sender:
+                        sender.send_messages(ServiceBusMessage("foo"))
diff --git a/sdk/servicebus/azure-servicebus/tests/test_sessions.py b/sdk/servicebus/azure-servicebus/tests/test_sessions.py
index 43d8fb0e02a27..8ddc1786af367 100644
--- a/sdk/servicebus/azure-servicebus/tests/test_sessions.py
+++ b/sdk/servicebus/azure-servicebus/tests/test_sessions.py
@@ -11,6 +11,7 @@
 import pytest
 import time
 import uuid
+import pickle
 from datetime import timedelta
 from azure.servicebus import (
@@ -33,7 +34,7 @@
     AutoLockRenewTimeout
 )
-from devtools_testutils import AzureMgmtTestCase
+from devtools_testutils import AzureMgmtRecordedTestCase
 from servicebus_preparer import (
     CachedServiceBusNamespacePreparer,
     CachedServiceBusQueuePreparer,
@@ -42,20 +43,24 @@
     ServiceBusSubscriptionPreparer,
     CachedServiceBusResourceGroupPreparer
 )
-from utilities import get_logger, print_message, sleep_until_expired
+from utilities import get_logger, print_message, sleep_until_expired, uamqp_transport as get_uamqp_transport, ArgPasser
+uamqp_transport_params, uamqp_transport_ids = get_uamqp_transport()
 _logger = get_logger(logging.DEBUG)
-class ServiceBusSessionTests(AzureMgmtTestCase):
+class TestServiceBusSession(AzureMgmtRecordedTestCase):
+
     @pytest.mark.liveTest
     @pytest.mark.live_test_only
     @CachedServiceBusResourceGroupPreparer()
     @CachedServiceBusNamespacePreparer(name_prefix='servicebustest')
     @ServiceBusQueuePreparer(name_prefix='servicebustest', requires_session=True)
-    def test_session_by_session_client_conn_str_receive_handler_peeklock(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs):
+    @pytest.mark.parametrize("uamqp_transport",
uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_session_by_session_client_conn_str_receive_handler_peeklock(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: session_id = str(uuid.uuid4()) sender = sb_client.get_queue_sender(servicebus_queue.name) @@ -171,14 +176,21 @@ def test_session_by_session_client_conn_str_receive_handler_peeklock(self, servi assert received_cnt_dic['0'] == 2 and received_cnt_dic['1'] == 2 and received_cnt_dic['2'] == 2 assert count == 6 + with pytest.raises(ServiceBusError): + receiver = sb_client.get_queue_receiver(servicebus_queue.name, session_id=1) + with receiver: + pass + @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', requires_session=True, lock_duration='PT5S') - def test_session_by_queue_client_conn_str_receive_handler_receiveanddelete(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_session_by_queue_client_conn_str_receive_handler_receiveanddelete(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: session_id = str(uuid.uuid4()) with sb_client.get_queue_sender(servicebus_queue.name) as sender: @@ -208,14 
+220,17 @@ def test_session_by_queue_client_conn_str_receive_handler_receiveanddelete(self, messages.append(message) assert len(messages) == 0 + @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer() @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', requires_session=True) - def test_session_by_session_client_conn_str_receive_handler_with_stop(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_session_by_session_client_conn_str_receive_handler_with_stop(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: session_id = str(uuid.uuid4()) with sb_client.get_queue_sender(servicebus_queue.name) as sender: @@ -254,9 +269,11 @@ def test_session_by_session_client_conn_str_receive_handler_with_stop(self, serv @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', requires_session=True) - def test_session_by_session_client_conn_str_receive_handler_with_no_session(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_session_by_session_client_conn_str_receive_handler_with_no_session(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, 
logging_enable=False, retry_total=1) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport, retry_total=1) as sb_client: with pytest.raises(OperationTimeoutError): with sb_client.get_queue_receiver(servicebus_queue.name, session_id=NEXT_AVAILABLE_SESSION, @@ -269,10 +286,12 @@ def test_session_by_session_client_conn_str_receive_handler_with_no_session(self @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', requires_session=True) - def test_session_connection_failure_is_idempotent(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_session_connection_failure_is_idempotent(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): #Technically this validates for all senders/receivers, not just session, but since it uses session to generate a recoverable failure, putting it in here. with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: # First let's just try the naive failure cases. 
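Every test in these files picks up the same `uamqp_transport` parametrization from `get_uamqp_transport()` plus `ArgPasser()`. The helper's return shape is not shown in this diff; a minimal sketch of how such paired `params`/`ids` lists drive `pytest.mark.parametrize` (the list contents are an assumption):

```python
# Assumed shapes for the two lists returned by the repo's
# get_uamqp_transport() helper -- illustrative, not the actual values.
uamqp_transport_params = [True, False]   # True -> uamqp, False -> pyamqp
uamqp_transport_ids = ["uamqp", "pyamqp"]

# pytest pairs params with ids positionally, so the two lists must
# have the same length; each test then runs once per transport.
cases = list(zip(uamqp_transport_ids, uamqp_transport_params))
```

With this pairing, pytest reports each case under a readable id such as `test_session_connection_failure_is_idempotent[pyamqp]`.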
receiver = sb_client.get_queue_receiver("THIS_IS_WRONG_ON_PURPOSE") @@ -302,14 +321,17 @@ def test_session_connection_failure_is_idempotent(self, servicebus_namespace_con messages.append(message) assert len(messages) == 1 + @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', requires_session=True) - def test_session_by_session_client_conn_str_receive_handler_with_inactive_session(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_session_by_session_client_conn_str_receive_handler_with_inactive_session(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: session_id = str(uuid.uuid4()) messages = [] @@ -323,14 +345,17 @@ def test_session_by_session_client_conn_str_receive_handler_with_inactive_sessio assert session._running assert len(messages) == 0 + @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', requires_session=True) - def test_session_by_servicebus_client_iter_messages_with_retrieve_deferred_receiver_complete(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def 
test_session_by_servicebus_client_iter_messages_with_retrieve_deferred_receiver_complete(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: with sb_client.get_queue_sender(servicebus_queue.name) as sender: deferred_messages = [] @@ -365,14 +390,17 @@ def test_session_by_servicebus_client_iter_messages_with_retrieve_deferred_recei receiver.renew_message_lock(message) receiver.complete_message(message) + @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', requires_session=True) - def test_session_by_servicebus_client_iter_messages_with_retrieve_deferred_receiver_deadletter(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_session_by_servicebus_client_iter_messages_with_retrieve_deferred_receiver_deadletter(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: with sb_client.get_queue_sender(servicebus_queue.name) as sender: deferred_messages = [] @@ -416,14 +444,17 @@ def test_session_by_servicebus_client_iter_messages_with_retrieve_deferred_recei receiver.complete_message(message) assert count == 10 + @pytest.mark.liveTest @pytest.mark.live_test_only 
@CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', requires_session=True) - def test_session_by_servicebus_client_iter_messages_with_retrieve_deferred_receiver_deletemode(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_session_by_servicebus_client_iter_messages_with_retrieve_deferred_receiver_deletemode(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: with sb_client.get_queue_sender(servicebus_queue.name) as sender: deferred_messages = [] @@ -454,14 +485,17 @@ def test_session_by_servicebus_client_iter_messages_with_retrieve_deferred_recei with pytest.raises(ServiceBusError): deferred = receiver.receive_deferred_messages(deferred_messages) + @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', requires_session=True) - def test_session_by_servicebus_client_iter_messages_with_retrieve_deferred_client(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_session_by_servicebus_client_iter_messages_with_retrieve_deferred_client(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): with ServiceBusClient.from_connection_string( - 
servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: deferred_messages = [] session_id = str(uuid.uuid4()) @@ -485,15 +519,18 @@ def test_session_by_servicebus_client_iter_messages_with_retrieve_deferred_clien with pytest.raises(MessageAlreadySettled): receiver.complete_message(message) + @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', requires_session=True) - def test_session_by_servicebus_client_receive_with_retrieve_deadletter(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_session_by_servicebus_client_receive_with_retrieve_deadletter(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: session_id = str(uuid.uuid4()) with sb_client.get_queue_receiver(servicebus_queue.name, @@ -535,9 +572,11 @@ def test_session_by_servicebus_client_receive_with_retrieve_deadletter(self, ser @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', requires_session=True) - def test_session_by_servicebus_client_browse_messages_client(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + 
@ArgPasser() + def test_session_by_servicebus_client_browse_messages_client(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: session_id = str(uuid.uuid4()) with sb_client.get_queue_sender(servicebus_queue.name) as sender: for i in range(5): @@ -572,10 +611,12 @@ def test_session_by_servicebus_client_browse_messages_client(self, servicebus_na @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', requires_session=True) - def test_session_by_servicebus_client_browse_messages_with_receiver(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_session_by_servicebus_client_browse_messages_with_receiver(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: session_id = str(uuid.uuid4()) with sb_client.get_queue_receiver(servicebus_queue.name, max_wait_time=5, session_id=session_id) as receiver: @@ -597,9 +638,11 @@ def test_session_by_servicebus_client_browse_messages_with_receiver(self, servic @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', requires_session=True) - def 
test_session_by_servicebus_client_renew_client_locks(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_session_by_servicebus_client_renew_client_locks(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: session_id = str(uuid.uuid4()) messages = [] @@ -638,16 +681,19 @@ def test_session_by_servicebus_client_renew_client_locks(self, servicebus_namesp with pytest.raises(SessionLockLostError): receiver.complete_message(messages[2]) + @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', requires_session=True, lock_duration='PT10S') - def test_session_by_conn_str_receive_handler_with_autolockrenew(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_session_by_conn_str_receive_handler_with_autolockrenew(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): session_id = str(uuid.uuid4()) with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: with sb_client.get_queue_sender(servicebus_queue.name) as sender: for i in range(10): @@ -713,16 +759,19 @@ def lock_lost_callback(renewable, error): 
renewer.close() assert len(messages) == 2 + @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', requires_session=True, lock_duration='PT10S') - def test_session_by_conn_str_receive_handler_with_auto_autolockrenew(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_session_by_conn_str_receive_handler_with_auto_autolockrenew(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): session_id = str(uuid.uuid4()) with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=True) as sb_client: + servicebus_namespace_connection_string, logging_enable=True, uamqp_transport=uamqp_transport) as sb_client: with sb_client.get_queue_sender(servicebus_queue.name) as sender: for i in range(10): @@ -816,10 +865,12 @@ def lock_lost_callback(renewable, error): @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @CachedServiceBusQueuePreparer(name_prefix='servicebustest', requires_session=True) - def test_session_receiver_partially_invalid_autolockrenew_mode(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_session_receiver_partially_invalid_autolockrenew_mode(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): session_id = str(uuid.uuid4()) with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + 
servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: with sb_client.get_queue_sender(servicebus_queue.name) as sender: sender.send_messages(ServiceBusMessage("test_message", session_id=session_id)) @@ -842,10 +893,12 @@ def should_not_run(*args, **kwargs): @CachedServiceBusResourceGroupPreparer() @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', requires_session=True) - def test_session_message_connection_closed(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_session_message_connection_closed(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: session_id = str(uuid.uuid4()) @@ -866,10 +919,12 @@ def test_session_message_connection_closed(self, servicebus_namespace_connection @CachedServiceBusResourceGroupPreparer() @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', requires_session=True, lock_duration='PT5S') - def test_session_message_expiry(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_session_message_expiry(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, 
uamqp_transport=uamqp_transport) as sb_client: session_id = str(uuid.uuid4()) with sb_client.get_queue_sender(servicebus_queue.name) as sender: @@ -905,10 +960,12 @@ def test_session_message_expiry(self, servicebus_namespace_connection_string, se @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', requires_session=True) - def test_session_schedule_message(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_session_schedule_message(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: session_id = str(uuid.uuid4()) enqueue_time = (utc_now() + timedelta(minutes=2)).replace(microsecond=0) @@ -941,10 +998,12 @@ def test_session_schedule_message(self, servicebus_namespace_connection_string, @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', requires_session=True) - def test_session_schedule_multiple_messages(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_session_schedule_multiple_messages(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + 
servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: session_id = str(uuid.uuid4()) enqueue_time = (utc_now() + timedelta(minutes=2)).replace(microsecond=0) @@ -982,10 +1041,12 @@ def test_session_schedule_multiple_messages(self, servicebus_namespace_connectio @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', requires_session=True) - def test_session_cancel_scheduled_messages(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_session_cancel_scheduled_messages(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: session_id = str(uuid.uuid4()) enqueue_time = (utc_now() + timedelta(minutes=2)).replace(microsecond=0) @@ -1006,16 +1067,18 @@ def test_session_cancel_scheduled_messages(self, servicebus_namespace_connection count += 1 assert len(messages) == 0 - + @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', requires_session=True) - def test_session_get_set_state_with_receiver(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_session_get_set_state_with_receiver(self, uamqp_transport, *, 
servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: session_id = str(uuid.uuid4()) with sb_client.get_queue_sender(servicebus_queue.name) as sender: @@ -1023,20 +1086,34 @@ def test_session_get_set_state_with_receiver(self, servicebus_namespace_connecti message = ServiceBusMessage("Handler message no. {}".format(i), session_id=session_id) sender.send_messages(message) - with sb_client.get_queue_receiver(servicebus_queue.name, session_id=session_id, max_wait_time=5) as receiver: - assert receiver.session.get_state(timeout=5) == None - receiver.session.set_state("first_state", timeout=5) + with sb_client.get_queue_receiver(servicebus_queue.name, session_id=session_id, max_wait_time=5) as session: + assert session.session.get_state(timeout=5) == None + session.session.set_state("first_state", timeout=5) count = 0 - for m in receiver: + for m in session: assert m.session_id == session_id count += 1 - state = receiver.session.get_state() + state = session.session.get_state() assert state == b'first_state' - receiver.session.set_state(None, timeout=5) - state = receiver.session.get_state() - assert not state assert count == 3 + session_id = str(uuid.uuid4()) + with sb_client.get_queue_sender(servicebus_queue.name) as sender: + for i in range(1): + message = ServiceBusMessage("Handler message no. 
{}".format(i), session_id=session_id) + sender.send_messages(message) + + with sb_client.get_queue_receiver(servicebus_queue.name, session_id=session_id, max_wait_time=5) as session: + assert session.session.get_state(timeout=5) == None + session.session.set_state(None, timeout=5) + count = 0 + for m in session: + assert m.session_id == session_id + count += 1 + state = session.session.get_state() + assert state == None + assert count == 1 + @pytest.mark.skip(reason="Needs list sessions") @pytest.mark.liveTest @@ -1044,10 +1121,12 @@ def test_session_get_set_state_with_receiver(self, servicebus_namespace_connecti @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', requires_session=True) - def test_session_by_servicebus_client_list_sessions_with_receiver(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_session_by_servicebus_client_list_sessions_with_receiver(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: sessions = [] start_time = utc_now() @@ -1075,10 +1154,12 @@ def test_session_by_servicebus_client_list_sessions_with_receiver(self, serviceb @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', requires_session=True) - def test_session_by_servicebus_client_list_sessions_with_client(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + 
@pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_session_by_servicebus_client_list_sessions_with_client(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: sessions = [] start_time = utc_now() @@ -1105,7 +1186,9 @@ def test_session_by_servicebus_client_list_sessions_with_client(self, servicebus @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusQueuePreparer(name_prefix='servicebustest', requires_session=True) - def test_session_by_servicebus_client_session_pool(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_session_by_servicebus_client_session_pool(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): messages = [] errors = [] concurrent_receivers = 5 @@ -1125,7 +1208,7 @@ def message_processing(sb_client): raise with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: sessions = [str(uuid.uuid4()) for i in range(concurrent_receivers)] @@ -1144,14 +1227,17 @@ def message_processing(sb_client): assert not errors assert len(messages) == 100 + @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') 
@ServiceBusQueuePreparer(name_prefix='servicebustest', requires_session=True) - def test_session_by_session_client_conn_str_receive_handler_peeklock_abandon(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_session_by_session_client_conn_str_receive_handler_peeklock_abandon(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: session_id = str(uuid.uuid4()) with sb_client.get_queue_sender(servicebus_queue.name) as sender: @@ -1169,16 +1255,19 @@ def test_session_by_session_client_conn_str_receive_handler_peeklock_abandon(sel if next_message.sequence_number == 1: return + @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusTopicPreparer(name_prefix='servicebustest') @ServiceBusSubscriptionPreparer(name_prefix='servicebustest', requires_session=True) - def test_session_basic_topic_subscription_send_and_receive(self, servicebus_namespace_connection_string, servicebus_topic, servicebus_subscription, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_session_basic_topic_subscription_send_and_receive(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_topic=None, servicebus_subscription=None, **kwargs): with ServiceBusClient.from_connection_string( servicebus_namespace_connection_string, - logging_enable=False + logging_enable=False, uamqp_transport=uamqp_transport ) as sb_client: with 
sb_client.get_topic_sender(topic_name=servicebus_topic.name) as sender: message = ServiceBusMessage(b"Sample topic message", session_id='test_session') @@ -1201,9 +1290,11 @@ def test_session_basic_topic_subscription_send_and_receive(self, servicebus_name @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @CachedServiceBusQueuePreparer(name_prefix='servicebustest', requires_session=True) - def test_session_non_session_send_to_session_queue_should_fail(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_session_non_session_send_to_session_queue_should_fail(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): with ServiceBusClient.from_connection_string( - servicebus_namespace_connection_string, logging_enable=False) as sb_client: + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: with sb_client.get_queue_sender(servicebus_queue.name) as sender: message = ServiceBusMessage("This should be an invalid non session message") diff --git a/sdk/servicebus/azure-servicebus/tests/test_subscriptions.py b/sdk/servicebus/azure-servicebus/tests/test_subscriptions.py index 0ad6a50ad73f1..af6613cf1c795 100644 --- a/sdk/servicebus/azure-servicebus/tests/test_subscriptions.py +++ b/sdk/servicebus/azure-servicebus/tests/test_subscriptions.py @@ -16,7 +16,7 @@ from azure.servicebus.exceptions import ServiceBusError from azure.servicebus._common.constants import ServiceBusSubQueue -from devtools_testutils import AzureMgmtTestCase, RandomNameResourceGroupPreparer +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer from servicebus_preparer import ( CachedServiceBusNamespacePreparer, CachedServiceBusTopicPreparer, @@ 
-26,23 +26,29 @@ CachedServiceBusResourceGroupPreparer, SERVICEBUS_ENDPOINT_SUFFIX ) -from utilities import get_logger, print_message +from utilities import get_logger, print_message, uamqp_transport as get_uamqp_transport, ArgPasser +uamqp_transport_params, uamqp_transport_ids = get_uamqp_transport() + _logger = get_logger(logging.DEBUG) -class ServiceBusSubscriptionTests(AzureMgmtTestCase): +class TestServiceBusSubscription(AzureMgmtRecordedTestCase): + @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusTopicPreparer(name_prefix='servicebustest') @ServiceBusSubscriptionPreparer(name_prefix='servicebustest') - def test_subscription_by_subscription_client_conn_str_receive_basic(self, servicebus_namespace_connection_string, servicebus_topic, servicebus_subscription, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_subscription_by_subscription_client_conn_str_receive_basic(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_topic=None, servicebus_subscription=None, **kwargs): with ServiceBusClient.from_connection_string( servicebus_namespace_connection_string, - logging_enable=False + logging_enable=False, + uamqp_transport=uamqp_transport ) as sb_client: with sb_client.get_topic_sender(topic_name=servicebus_topic.name) as sender: message = ServiceBusMessage(b"Sample topic message") @@ -64,22 +70,22 @@ def test_subscription_by_subscription_client_conn_str_receive_basic(self, servic with pytest.raises(ValueError): receiver.receive_messages(max_wait_time=-1) - with pytest.raises(ValueError): - receiver._get_streaming_message_iter(max_wait_time=0) - count = 0 for message in receiver: count += 1 receiver.complete_message(message) assert count == 1 + @pytest.mark.liveTest @pytest.mark.live_test_only 
@CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusTopicPreparer(name_prefix='servicebustest') @ServiceBusSubscriptionPreparer(name_prefix='servicebustest') - def test_subscription_by_sas_token_credential_conn_str_send_basic(self, servicebus_namespace, servicebus_namespace_key_name, servicebus_namespace_primary_key, servicebus_topic, servicebus_subscription, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_subscription_by_sas_token_credential_conn_str_send_basic(self, uamqp_transport, *, servicebus_namespace=None, servicebus_namespace_key_name=None, servicebus_namespace_primary_key=None, servicebus_topic=None, servicebus_subscription=None, **kwargs): fully_qualified_namespace = f"{servicebus_namespace.name}{SERVICEBUS_ENDPOINT_SUFFIX}" with ServiceBusClient( fully_qualified_namespace=fully_qualified_namespace, @@ -87,7 +93,8 @@ def test_subscription_by_sas_token_credential_conn_str_send_basic(self, serviceb policy=servicebus_namespace_key_name, key=servicebus_namespace_primary_key ), - logging_enable=False + logging_enable=False, + uamqp_transport=uamqp_transport ) as sb_client: with sb_client.get_topic_sender(topic_name=servicebus_topic.name) as sender: @@ -112,13 +119,17 @@ def test_subscription_by_sas_token_credential_conn_str_send_basic(self, serviceb @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusTopicPreparer(name_prefix='servicebustest') @ServiceBusSubscriptionPreparer(name_prefix='servicebustest') - def test_subscription_by_servicebus_client_list_subscriptions(self, servicebus_namespace, servicebus_namespace_key_name, servicebus_namespace_primary_key, servicebus_topic, servicebus_subscription, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def 
test_subscription_by_servicebus_client_list_subscriptions(self, uamqp_transport, *, servicebus_namespace=None, servicebus_namespace_key_name=None, servicebus_namespace_primary_key=None, servicebus_topic=None, servicebus_subscription=None, **kwargs): client = ServiceBusClient( service_namespace=servicebus_namespace.name, shared_access_key_name=servicebus_namespace_key_name, shared_access_key_value=servicebus_namespace_primary_key, - debug=False) + logging_enable=False, + uamqp_transport=uamqp_transport + ) subs = client.list_subscriptions(servicebus_topic.name) assert len(subs) >= 1 @@ -126,17 +137,21 @@ def test_subscription_by_servicebus_client_list_subscriptions(self, servicebus_n assert subs[0].name == servicebus_subscription.name assert subs[0].topic_name == servicebus_topic.name + @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusTopicPreparer(name_prefix='servicebustest') @ServiceBusSubscriptionPreparer(name_prefix='servicebustest') - def test_subscription_by_servicebus_client_receive_batch_with_deadletter(self, servicebus_namespace_connection_string, servicebus_topic, servicebus_subscription, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_subscription_by_servicebus_client_receive_batch_with_deadletter(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_topic=None, servicebus_subscription=None, **kwargs): with ServiceBusClient.from_connection_string( servicebus_namespace_connection_string, - logging_enable=False + logging_enable=False, + uamqp_transport=uamqp_transport ) as sb_client: with sb_client.get_subscription_receiver( diff --git a/sdk/servicebus/azure-servicebus/tests/test_topic.py b/sdk/servicebus/azure-servicebus/tests/test_topic.py index a662a25b03f22..a9989999f76c9 100644 --- 
a/sdk/servicebus/azure-servicebus/tests/test_topic.py +++ b/sdk/servicebus/azure-servicebus/tests/test_topic.py @@ -11,7 +11,7 @@ import time from datetime import datetime, timedelta -from devtools_testutils import AzureMgmtTestCase, RandomNameResourceGroupPreparer +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer from azure.servicebus import ServiceBusClient from azure.servicebus._base_handler import ServiceBusSharedKeyCredential @@ -24,22 +24,27 @@ CachedServiceBusResourceGroupPreparer, SERVICEBUS_ENDPOINT_SUFFIX ) -from utilities import get_logger, print_message +from utilities import get_logger, print_message, uamqp_transport as get_uamqp_transport, ArgPasser +uamqp_transport_params, uamqp_transport_ids = get_uamqp_transport() + _logger = get_logger(logging.DEBUG) -class ServiceBusTopicsTests(AzureMgmtTestCase): +class TestServiceBusTopics(AzureMgmtRecordedTestCase): @pytest.mark.liveTest @pytest.mark.live_test_only @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @CachedServiceBusTopicPreparer(name_prefix='servicebustest') - def test_topic_by_servicebus_client_conn_str_send_basic(self, servicebus_namespace_connection_string, servicebus_topic, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_topic_by_servicebus_client_conn_str_send_basic(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_topic=None, **kwargs): with ServiceBusClient.from_connection_string( servicebus_namespace_connection_string, - logging_enable=False + logging_enable=False, + uamqp_transport=uamqp_transport ) as sb_client: with sb_client.get_topic_sender(servicebus_topic.name) as sender: message = ServiceBusMessage(b"Sample topic message") @@ -50,7 +55,9 @@ def test_topic_by_servicebus_client_conn_str_send_basic(self, servicebus_namespa 
@CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') @CachedServiceBusTopicPreparer(name_prefix='servicebustest') - def test_topic_by_sas_token_credential_conn_str_send_basic(self, servicebus_namespace, servicebus_namespace_key_name, servicebus_namespace_primary_key, servicebus_topic, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_topic_by_sas_token_credential_conn_str_send_basic(self, uamqp_transport, *, servicebus_namespace=None, servicebus_namespace_key_name=None, servicebus_namespace_primary_key=None, servicebus_topic=None, **kwargs): fully_qualified_namespace = f"{servicebus_namespace.name}{SERVICEBUS_ENDPOINT_SUFFIX}" with ServiceBusClient( fully_qualified_namespace=fully_qualified_namespace, @@ -58,7 +65,8 @@ def test_topic_by_sas_token_credential_conn_str_send_basic(self, servicebus_name policy=servicebus_namespace_key_name, key=servicebus_namespace_primary_key ), - logging_enable=False + logging_enable=False, + uamqp_transport=uamqp_transport ) as sb_client: with sb_client.get_topic_sender(servicebus_topic.name) as sender: message = ServiceBusMessage(b"Sample topic message") @@ -70,13 +78,17 @@ def test_topic_by_sas_token_credential_conn_str_send_basic(self, servicebus_name @RandomNameResourceGroupPreparer() @ServiceBusNamespacePreparer(name_prefix='servicebustest') @ServiceBusTopicPreparer(name_prefix='servicebustest') - def test_topic_by_servicebus_client_list_topics(self, servicebus_namespace, servicebus_namespace_key_name, servicebus_namespace_primary_key, servicebus_topic, **kwargs): + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_topic_by_servicebus_client_list_topics(self, uamqp_transport, *, servicebus_namespace=None, servicebus_namespace_key_name=None, servicebus_namespace_primary_key=None, servicebus_topic=None, 
**kwargs): client = ServiceBusClient( service_namespace=servicebus_namespace.name, shared_access_key_name=servicebus_namespace_key_name, shared_access_key_value=servicebus_namespace_primary_key, - debug=False) + logging_enable=False, + uamqp_transport=uamqp_transport + ) topics = client.list_topics() assert len(topics) >= 1 diff --git a/sdk/servicebus/azure-servicebus/tests/unittests/test_errors.py b/sdk/servicebus/azure-servicebus/tests/unittests/test_errors.py new file mode 100644 index 0000000000000..dcc8638d503a9 --- /dev/null +++ b/sdk/servicebus/azure-servicebus/tests/unittests/test_errors.py @@ -0,0 +1,81 @@ +import logging +import pytest + +from azure.servicebus.exceptions import ( + ServiceBusConnectionError, + ServiceBusError +) +try: + from uamqp import errors as uamqp_AMQPErrors, constants as uamqp_AMQPConstants + from azure.servicebus._transport._uamqp_transport import UamqpTransport +except ImportError: + pass +from azure.servicebus._transport._pyamqp_transport import PyamqpTransport +from azure.servicebus._pyamqp import error as AMQPErrors + +from utilities import uamqp_transport as get_uamqp_transport +uamqp_transport_params, uamqp_transport_ids = get_uamqp_transport() + +@pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) +def test_link_idle_timeout(uamqp_transport): + if uamqp_transport: + amqp_transport = UamqpTransport + amqp_err_cls = uamqp_AMQPErrors.LinkDetach + amqp_err_condition = uamqp_AMQPConstants.ErrorCodes.LinkDetachForced + else: + amqp_transport = PyamqpTransport + amqp_err_cls = AMQPErrors.AMQPLinkError + amqp_err_condition = AMQPErrors.ErrorCondition.LinkDetachForced + amqp_error = amqp_err_cls(amqp_err_condition, description="Details: AmqpMessageConsumer.IdleTimerExpired: Idle timeout: 00:10:00.") + logger = logging.getLogger("testlogger") + sb_error = amqp_transport.create_servicebus_exception(logger, amqp_error) + assert isinstance(sb_error, ServiceBusConnectionError) + assert 
sb_error._retryable + assert sb_error._shutdown_handler + + +@pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) +def test_unknown_connection_error(uamqp_transport): + if uamqp_transport: + amqp_transport = UamqpTransport + amqp_conn_err_cls = uamqp_AMQPErrors.AMQPConnectionError + amqp_err_cls = uamqp_AMQPErrors.AMQPError + amqp_unknown_err_condition = uamqp_AMQPConstants.ErrorCodes.UnknownError + else: + amqp_transport = PyamqpTransport + amqp_conn_err_cls = AMQPErrors.AMQPConnectionError + amqp_err_cls = AMQPErrors.AMQPError + amqp_unknown_err_condition = AMQPErrors.ErrorCondition.UnknownError + logger = logging.getLogger("testlogger") + amqp_error = amqp_conn_err_cls(amqp_unknown_err_condition) + sb_error = amqp_transport.create_servicebus_exception(logger, amqp_error) + assert isinstance(sb_error, ServiceBusConnectionError) + assert sb_error._retryable + assert sb_error._shutdown_handler + + amqp_error = amqp_err_cls(amqp_unknown_err_condition) + sb_error = amqp_transport.create_servicebus_exception(logger, amqp_error) + assert not isinstance(sb_error, ServiceBusConnectionError) + assert isinstance(sb_error, ServiceBusError) + assert not sb_error._retryable + assert sb_error._shutdown_handler + +@pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) +def test_internal_server_error(uamqp_transport): + if uamqp_transport: + amqp_transport = UamqpTransport + amqp_err_cls = uamqp_AMQPErrors.LinkDetach + err_condition = uamqp_AMQPConstants.ErrorCodes.InternalServerError + else: + amqp_transport = PyamqpTransport + amqp_err_cls = AMQPErrors.AMQPLinkError + err_condition = AMQPErrors.ErrorCondition.InternalError + logger = logging.getLogger("testlogger") + amqp_error = amqp_err_cls( + description="The service was unable to process the request; please retry the operation.", + condition=err_condition + ) + sb_error = amqp_transport.create_servicebus_exception(logger, amqp_error) + assert 
isinstance(sb_error, ServiceBusError) + assert sb_error._retryable + assert sb_error._shutdown_handler diff --git a/sdk/servicebus/azure-servicebus/tests/utilities.py b/sdk/servicebus/azure-servicebus/tests/utilities.py index cd0ca6cb958bf..ec066eaf59caa 100644 --- a/sdk/servicebus/azure-servicebus/tests/utilities.py +++ b/sdk/servicebus/azure-servicebus/tests/utilities.py @@ -7,6 +7,11 @@ import logging import sys import time +try: + import uamqp + uamqp_available = True +except (ModuleNotFoundError, ImportError): + uamqp_available = False from azure.servicebus._common.utils import utc_now def _get_default_handler(): @@ -24,26 +29,52 @@ def _build_logger(name, level): # Note: This was the initial generic logger entry point, kept to allow us to # move to more fine-grained logging controls incrementally. -def get_logger(level, uamqp_level=logging.INFO): - _build_logger("uamqp", uamqp_level) +def get_logger(level, amqp_level=logging.INFO): + _build_logger("azure.servicebus._pyamqp", amqp_level) + _build_logger("uamqp", amqp_level) return _build_logger("azure", level) def print_message(_logger, message): - _logger.info("Receiving: {}".format(message)) - _logger.debug("Time to live: {}".format(message.time_to_live)) - _logger.debug("Sequence number: {}".format(message.sequence_number)) - _logger.debug("Enqueue Sequence numger: {}".format(message.enqueued_sequence_number)) - _logger.debug("Partition Key: {}".format(message.partition_key)) - _logger.debug("Application Properties: {}".format(message.application_properties)) - _logger.debug("Delivery count: {}".format(message.delivery_count)) + _logger.info(f"Receiving: {message}") + _logger.debug(f"Time to live: {message.time_to_live}") + _logger.debug(f"Sequence number: {message.sequence_number}") + _logger.debug(f"Enqueued Sequence number: {message.enqueued_sequence_number}") + _logger.debug(f"Partition Key: {message.partition_key}") + _logger.debug(f"Application Properties: {message.application_properties}") + 
_logger.debug(f"Delivery count: {message.delivery_count}") try: - _logger.debug("Locked until: {}".format(message.locked_until_utc)) - _logger.debug("Lock Token: {}".format(message.lock_token)) + _logger.debug(f"Locked until: {message.locked_until_utc}") + _logger.debug(f"Lock Token: {message.lock_token}") except (TypeError, AttributeError): pass - _logger.debug("Enqueued time: {}".format(message.enqueued_time_utc)) + _logger.debug(f"Enqueued time: {message.enqueued_time_utc}") def sleep_until_expired(entity): - time.sleep(max(0,(entity.locked_until_utc - utc_now()).total_seconds()+1)) \ No newline at end of file + time.sleep(max(0,(entity.locked_until_utc - utc_now()).total_seconds()+1)) + + +def uamqp_transport(use_uamqp=uamqp_available, use_pyamqp=True): + uamqp_transport_params = [] + uamqp_transport_ids = [] + if use_uamqp: + uamqp_transport_params.append(True) + uamqp_transport_ids.append("uamqp") + if use_pyamqp: + uamqp_transport_params.append(False) + uamqp_transport_ids.append("pyamqp") + return uamqp_transport_params, uamqp_transport_ids + +class ArgPasser: + def __call__(self, fn): + def _preparer(test_class, uamqp_transport, **kwargs): + fn(test_class, uamqp_transport=uamqp_transport, **kwargs) + return _preparer + +class ArgPasserAsync: + def __call__(self, fn): + async def _preparer(test_class, uamqp_transport, **kwargs): + await fn(test_class, uamqp_transport=uamqp_transport, **kwargs) + return _preparer + \ No newline at end of file diff --git a/sdk/servicebus/perf-resources.bicep b/sdk/servicebus/perf-resources.bicep new file mode 100644 index 0000000000000..e17e1daef2ab0 --- /dev/null +++ b/sdk/servicebus/perf-resources.bicep @@ -0,0 +1,68 @@ +param baseName string = resourceGroup().name +param location string = resourceGroup().location + +var serviceBusNamespaceName = 'sb-${baseName}' +var serviceBusQueueName = '${serviceBusNamespaceName}-queue' +var serviceBusTopicName = '${serviceBusNamespaceName}-topic' +var serviceBusSubscriptionName = 
'${serviceBusNamespaceName}-subscription' +var defaultSASKeyName = 'RootManageSharedAccessKey' +var sbVersion = '2017-04-01' + + + +resource serviceBusNamespace 'Microsoft.ServiceBus/namespaces@2017-04-01' = { + name: serviceBusNamespaceName + location: location + sku: { + name: 'Standard' + } + properties: {} +} + +resource serviceBusQueue 'Microsoft.ServiceBus/namespaces/queues@2017-04-01' = { + parent: serviceBusNamespace + name: serviceBusQueueName + properties: { + lockDuration: 'PT5M' + maxSizeInMegabytes: 4096 + requiresDuplicateDetection: false + requiresSession: false + defaultMessageTimeToLive: 'P10675199DT2H48M5.4775807S' + deadLetteringOnMessageExpiration: false + duplicateDetectionHistoryTimeWindow: 'PT10M' + maxDeliveryCount: 10 + autoDeleteOnIdle: 'P10675199DT2H48M5.4775807S' + enablePartitioning: false + enableExpress: false + } +} + +resource serviceBusTopic 'Microsoft.ServiceBus/namespaces/topics@2017-04-01' = { + parent: serviceBusNamespace + name: serviceBusTopicName + properties: { + autoDeleteOnIdle: 'P10675199DT2H48M5.4775807S' + defaultMessageTimeToLive: 'P10675199DT2H48M5.4775807S' + duplicateDetectionHistoryTimeWindow: 'PT10M' + enableBatchedOperations: true + enableExpress: false + enablePartitioning: false + maxSizeInMegabytes: 4096 + requiresDuplicateDetection: false + status: 'Active' + supportOrdering: true + } +} + +resource serviceBusSubscription 'Microsoft.ServiceBus/namespaces/topics/subscriptions@2017-04-01' = { + parent: serviceBusTopic + name: serviceBusSubscriptionName + properties: { + } +} + +var authRuleResourceId = resourceId('Microsoft.ServiceBus/namespaces/authorizationRules', serviceBusNamespace.name, defaultSASKeyName) +output AZURE_SERVICEBUS_CONNECTION_STRING string = listkeys(authRuleResourceId, sbVersion).primaryConnectionString +output AZURE_SERVICEBUS_QUEUE_NAME string = serviceBusQueue.name +output AZURE_SERVICEBUS_TOPIC_NAME string = serviceBusTopic.name +output AZURE_SERVICEBUS_SUBSCRIPTION_NAME string = 
serviceBusSubscription.name diff --git a/sdk/servicebus/perf-tests.yml b/sdk/servicebus/perf-tests.yml new file mode 100644 index 0000000000000..c4b3a02621fbe --- /dev/null +++ b/sdk/servicebus/perf-tests.yml @@ -0,0 +1,44 @@ +Service: servicebus + +Project: sdk/servicebus/azure-servicebus + +PrimaryPackage: azure-servicebus + +PackageVersions: +- azure-core: 1.26.3 + azure-servicebus: 7.8.3 +- azure-core: source + azure-servicebus: source + +Tests: +- Test: send-queue-message-batch + Class: SendQueueMessageBatchTest + Arguments: + - --message-size 1024 --batch-size 100 + - --message-size 1024 --batch-size 100 --uamqp-transport + - --message-size 1024 --batch-size 100 --transport-type 1 + - --message-size 1024 --batch-size 100 --transport-type 1 --uamqp-transport + +- Test: receive-queue-message-batch + Class: ReceiveQueueMessageBatchTest + Arguments: + - --message-size 2000 --num-messages 50 --preload 10000 + - --message-size 2000 --num-messages 50 --preload 10000 --uamqp-transport + - --message-size 2000 --num-messages 50 --preload 10000 --transport-type 1 + - --message-size 2000 --num-messages 50 --preload 10000 --transport-type 1 --uamqp-transport + +- Test: send-subscription-message-batch + Class: SendTopicMessageBatchTest + Arguments: + - --message-size 1024 --batch-size 100 + - --message-size 1024 --batch-size 100 --uamqp-transport + - --message-size 1024 --batch-size 100 --transport-type 1 + - --message-size 1024 --batch-size 100 --transport-type 1 --uamqp-transport + +- Test: receive-subscription-message-batch + Class: ReceiveSubscriptionMessageBatchTest + Arguments: + - --message-size 2000 --num-messages 50 --preload 10000 + - --message-size 2000 --num-messages 50 --preload 10000 --uamqp-transport + - --message-size 2000 --num-messages 50 --preload 10000 --transport-type 1 + - --message-size 2000 --num-messages 50 --preload 10000 --transport-type 1 --uamqp-transport \ No newline at end of file diff --git a/sdk/servicebus/perf.yml b/sdk/servicebus/perf.yml 
new file mode 100644 index 0000000000000..e02887e68ca9e --- /dev/null +++ b/sdk/servicebus/perf.yml @@ -0,0 +1,36 @@ +parameters: +- name: LanguageVersion + displayName: LanguageVersion (3.7, 3.8, 3.9, 3.10, 3.11) + type: string + default: '3.11' +- name: PackageVersions + displayName: PackageVersions (regex of package versions to run) + type: string + default: '7|source' +- name: Tests + displayName: Tests (regex of tests to run) + type: string + default: '^(send-queue-message-batch|send-subscription-message-batch|receive-queue-message-batch|receive-subscription-message-batch)$' +- name: Arguments + displayName: Arguments (regex of arguments to run) + type: string + default: '.*' +- name: Iterations + displayName: Iterations (times to run each test) + type: number + default: '5' +- name: AdditionalArguments + displayName: AdditionalArguments (passed to PerfAutomation) + type: string + default: ' ' + +extends: + template: /eng/pipelines/templates/jobs/perf.yml + parameters: + ServiceDirectory: servicebus + LanguageVersion: ${{ parameters.LanguageVersion }} + PackageVersions: ${{ parameters.PackageVersions }} + Tests: ${{ parameters.Tests }} + Arguments: ${{ parameters.Arguments }} + Iterations: ${{ parameters.Iterations }} + AdditionalArguments: ${{ parameters.AdditionalArguments }} \ No newline at end of file diff --git a/sdk/servicebus/tests.yml b/sdk/servicebus/tests.yml index 8c24ecbf2c8ba..5cd072d4746c5 100644 --- a/sdk/servicebus/tests.yml +++ b/sdk/servicebus/tests.yml @@ -4,7 +4,7 @@ stages: - template: ../../eng/pipelines/templates/stages/archetype-sdk-tests.yml parameters: ServiceDirectory: servicebus - TestTimeoutInMinutes: 300 + TestTimeoutInMinutes: 360 BuildTargetingString: azure-servicebus* EnvVars: AZURE_SUBSCRIPTION_ID: $(SERVICEBUS_SUBSCRIPTION_ID)
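For orientation, outside the patch itself: a minimal sketch of how the `uamqp_transport()` helper added to `tests/utilities.py` feeds pytest parametrization. The function body mirrors the patched helper; `use_uamqp=True` is assumed here for illustration, since in the real module the default comes from whether `uamqp` is importable.

```python
import pytest

# Mirrors the uamqp_transport() helper from tests/utilities.py: it builds
# parallel lists of pytest params and ids so each test runs once per
# enabled AMQP transport (True -> legacy uamqp, False -> pure-Python pyamqp).
def uamqp_transport(use_uamqp=False, use_pyamqp=True):
    params, ids = [], []
    if use_uamqp:
        params.append(True)
        ids.append("uamqp")
    if use_pyamqp:
        params.append(False)
        ids.append("pyamqp")
    return params, ids

# With both transports enabled, a parametrized test runs twice.
uamqp_transport_params, uamqp_transport_ids = uamqp_transport(use_uamqp=True)

@pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids)
def test_example(uamqp_transport):
    # The test body branches on the flag, as the error-mapping tests above do.
    assert uamqp_transport in (True, False)
```

Keeping params and ids as parallel lists lets the same tuple be unpacked directly into `@pytest.mark.parametrize`, which is why the error-mapping tests in the hunk above reference `uamqp_transport_params` and `uamqp_transport_ids` by name.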