Skip to content

Commit

Permalink
[ServiceBus] adding typing/type completeness (#33283)
Browse files Browse the repository at this point in the history
* verifytypes mgmt

* type more models

* 81%

* update dictmixin type

* more typing

* 96.4%

* mypy

* type pyamqp message backcompat

* mgmt mypy errors

* fix more typing

* fix base handler context manager typing

* update swagger to include transforms

* libba

* lint

* fix settler message backcompat

* fix admin corr filter

* thanks libber

* update settler pop

* sync w/ eh

* update mgmt internal model creation

* update if_match in admin

* kashifs comments

* add reasons for casting

* make error description optional

* update min core version

* remove cast in legacymessage

* remove cast in legacymessage
  • Loading branch information
swathipil committed Dec 18, 2023
1 parent 6c67965 commit f8f223d
Show file tree
Hide file tree
Showing 42 changed files with 1,912 additions and 1,059 deletions.
2 changes: 2 additions & 0 deletions sdk/servicebus/azure-servicebus/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

### Other Changes

- Updated minimum `azure-core` version to 1.28.0.

## 7.11.4 (2023-11-13)

### Bugs Fixed
Expand Down
2 changes: 1 addition & 1 deletion sdk/servicebus/azure-servicebus/assets.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
"AssetsRepo": "Azure/azure-sdk-assets",
"AssetsRepoPrefixPath": "python",
"TagPrefix": "python/servicebus/azure-servicebus",
"Tag": "python/servicebus/azure-servicebus_81b962314c"
"Tag": "python/servicebus/azure-servicebus_5e25f00bf7"
}
14 changes: 2 additions & 12 deletions sdk/servicebus/azure-servicebus/azure/servicebus/_base_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def __init__(
self._amqp_transport = kwargs.pop("amqp_transport", PyamqpTransport)

# If the user provided http:// or sb://, let's be polite and strip that.
self.fully_qualified_namespace = strip_protocol_from_uri(
self.fully_qualified_namespace: str = strip_protocol_from_uri(
fully_qualified_namespace.strip()
)
self._entity_name = entity_name
Expand Down Expand Up @@ -330,17 +330,7 @@ def _convert_connection_string_to_kwargs(

return kwargs

def __enter__(self):
if self._shutdown.is_set():
raise ValueError(
"The handler has already been shutdown. Please use ServiceBusClient to "
"create a new instance."
)

self._open_with_retry()
return self

def __exit__(self, *args):
def __exit__(self, *args: Any) -> None:
self.close()

def _handle_exception(self, exception: BaseException) -> "ServiceBusError":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,17 @@ def __init__(self, **kwargs):
self.auto_reconnect = kwargs.get("auto_reconnect", True)
self.keep_alive = kwargs.get("keep_alive", 30)
self.timeout: float = kwargs.get("timeout", 60)
self.socket_timeout = kwargs.get("socket_timeout", 0.2)
default_socket_timeout = 0.2

if self.http_proxy or self.transport_type.value == TransportType.AmqpOverWebsocket.value:
self.transport_type = TransportType.AmqpOverWebsocket
self.connection_port = DEFAULT_AMQP_WSS_PORT
self.socket_timeout = kwargs.get("socket_timeout", 1)
default_socket_timeout = 1
if amqp_transport.KIND == "pyamqp":
self.hostname += "/$servicebus/websocket"

self.socket_timeout = kwargs.get("socket_timeout") or default_socket_timeout

# custom end point
if self.custom_endpoint_address:
# if the custom_endpoint_address doesn't include the schema,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------

from typing import Optional
from ..management._models import DictMixin
from .._base_handler import _parse_conn_str

Expand All @@ -12,64 +13,72 @@ class ServiceBusConnectionStringProperties(DictMixin):
Properties of a connection string.
"""

def __init__(self, **kwargs):
self._fully_qualified_namespace = kwargs.pop("fully_qualified_namespace", None)
self._endpoint = kwargs.pop("endpoint", None)
self._entity_path = kwargs.pop("entity_path", None)
self._shared_access_signature = kwargs.pop("shared_access_signature", None)
self._shared_access_key_name = kwargs.pop("shared_access_key_name", None)
self._shared_access_key = kwargs.pop("shared_access_key", None)
def __init__(
self,
*,
fully_qualified_namespace: str,
endpoint: str,
entity_path: Optional[str] = None,
shared_access_signature: Optional[str] = None,
shared_access_key_name: Optional[str] = None,
shared_access_key: Optional[str] = None
):
self._fully_qualified_namespace = fully_qualified_namespace
self._endpoint = endpoint
self._entity_path = entity_path
self._shared_access_signature = shared_access_signature
self._shared_access_key_name = shared_access_key_name
self._shared_access_key = shared_access_key

@property
def fully_qualified_namespace(self):
def fully_qualified_namespace(self) -> str:
"""The fully qualified host name for the Service Bus namespace.
The namespace format is: `<yournamespace>.servicebus.windows.net`.
:rtype: str
"""
return self._fully_qualified_namespace

@property
def endpoint(self):
def endpoint(self) -> str:
"""The endpoint for the Service Bus resource. In the format sb://<FQDN>/
:rtype: str
"""
return self._endpoint

@property
def entity_path(self):
def entity_path(self) -> Optional[str]:
"""Optional. Represents the name of the queue/topic.
:rtype: str
:rtype: str or None
"""

return self._entity_path

@property
def shared_access_signature(self):
def shared_access_signature(self) -> Optional[str]:
"""
This can be provided instead of the shared_access_key_name and the shared_access_key.
:rtype: str
:rtype: str or None
"""
return self._shared_access_signature

@property
def shared_access_key_name(self):
def shared_access_key_name(self) -> Optional[str]:
"""
The name of the shared_access_key. This must be used along with the shared_access_key.
:rtype: str
:rtype: str or None
"""
return self._shared_access_key_name

@property
def shared_access_key(self):
def shared_access_key(self) -> Optional[str]:
"""
The shared_access_key can be used along with the shared_access_key_name as a credential.
:rtype: str
:rtype: str or None
"""
return self._shared_access_key


def parse_connection_string(conn_str):
# type(str) -> ServiceBusConnectionStringProperties
def parse_connection_string(conn_str: str) -> "ServiceBusConnectionStringProperties":
"""Parse the connection string into a properties bag containing its component parts.
:param conn_str: The connection string that has to be parsed.
Expand All @@ -79,12 +88,11 @@ def parse_connection_string(conn_str):
"""
fully_qualified_namespace, policy, key, entity, signature = _parse_conn_str(conn_str, True)[:-1]
endpoint = "sb://" + fully_qualified_namespace + "/"
props = {
"fully_qualified_namespace": fully_qualified_namespace,
"endpoint": endpoint,
"entity_path": entity,
"shared_access_signature": signature,
"shared_access_key_name": policy,
"shared_access_key": key,
}
return ServiceBusConnectionStringProperties(**props)
return ServiceBusConnectionStringProperties(
fully_qualified_namespace=fully_qualified_namespace,
endpoint=endpoint,
entity_path=entity,
shared_access_signature=signature,
shared_access_key_name=policy,
shared_access_key=key,
)
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import time
from concurrent.futures import ThreadPoolExecutor, TimeoutError as FuturesTimeoutError
import queue
from typing import TYPE_CHECKING, Union, Optional
from typing import TYPE_CHECKING, Union, Optional, Any

from .._servicebus_receiver import ServiceBusReceiver
from .._servicebus_session import ServiceBusSession
Expand Down Expand Up @@ -106,7 +106,7 @@ def __init__(
self._renew_tasks = queue.Queue() # type: ignore
self._infer_max_workers_time = 1

def __enter__(self):
def __enter__(self) -> "AutoLockRenewer":
if self._shutdown.is_set():
raise ServiceBusError(
"The AutoLockRenewer has already been shutdown. Please create a new instance for"
Expand All @@ -116,7 +116,7 @@ def __enter__(self):
self._init_workers()
return self

def __exit__(self, *args):
def __exit__(self, *args: Any) -> None:
self.close()

def _init_workers(self):
Expand Down Expand Up @@ -309,7 +309,7 @@ def register(
)
)

def close(self, wait=True):
def close(self, wait: bool = True) -> None:
"""Cease autorenewal by shutting down the thread pool to clean up any remaining lock renewal threads.
:param wait: Whether to block until thread pool has shutdown. Default is `True`.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -811,7 +811,7 @@ def __init__(
message: Union["Message", "pyamqp_Message"],
receive_mode: Union[ServiceBusReceiveMode, str] = ServiceBusReceiveMode.PEEK_LOCK,
frame: Optional["TransferFrame"] = None,
**kwargs
**kwargs: Any
) -> None:
self._amqp_transport = kwargs.pop("amqp_transport", PyamqpTransport)
super(ServiceBusReceivedMessage, self).__init__(None, message=message) # type: ignore
Expand All @@ -837,13 +837,13 @@ def __init__(
) from None
self._expiry: Optional[datetime.datetime] = None

def __getstate__(self):
def __getstate__(self) -> Dict[str, Any]:
state = self.__dict__.copy()
state['_receiver'] = None
state['_uamqp_message'] = None
return state

def __setstate__(self, state):
def __setstate__(self, state: Dict[str, Any]) -> None:
self.__dict__.update(state)

@property
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
class ReceiverMixin(object): # pylint: disable=too-many-instance-attributes
def _populate_attributes(self, **kwargs):
self._amqp_transport: Union["AmqpTransport", "AmqpTransportAsync"]
self.entity_path: str
if kwargs.get("subscription_name"):
self._subscription_name = kwargs.get("subscription_name")
self._is_subscription = True
Expand All @@ -46,7 +47,7 @@ def _populate_attributes(self, **kwargs):
is_session=bool(self._session_id)
)

self._name = kwargs.get("client_identifier", "SBReceiver-{}".format(uuid.uuid4()))
self._name = kwargs.get("client_identifier") or "SBReceiver-{}".format(uuid.uuid4())
self._last_received_sequenced_number = None
self._message_iter = None
self._connection = kwargs.get("connection")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,14 @@
import struct
import uuid
import logging
from typing import List, Optional, Tuple, Dict, Callable, Any, cast, Union # pylint: disable=unused-import
from typing import List, Optional, Tuple, Dict, Callable, Any, cast, Union, TYPE_CHECKING


from .message import Message, Header, Properties

if TYPE_CHECKING:
from .message import MessageDict

_LOGGER = logging.getLogger(__name__)
_HEADER_PREFIX = memoryview(b'AMQP')
_COMPOSITES = {
Expand Down Expand Up @@ -276,7 +279,9 @@ def decode_payload(buffer):
message["footer"] = value
# TODO: we can possibly swap out the Message construct with a TypedDict
# for both input and output so we get the best of both.
return Message(**message)
# casting to TypedDict with named fields to allow for unpacking with **
message_properties = cast("MessageDict", message)
return Message(**message_properties)


def decode_frame(data):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
except ImportError:
from typing_extensions import TypeAlias

from typing_extensions import Buffer


from .types import (
TYPE,
Expand Down Expand Up @@ -619,7 +621,7 @@ def encode_fields(value):


def encode_annotations(value):
# type: (Optional[Dict[str, Any]]) -> Dict[str, Any]
# type: (Optional[Dict[Union[str, bytes] , Any]]) -> Dict[str, Any]
"""The annotations type is a map where the keys are restricted to be of type symbol or of type ulong.
All ulong keys, and all symbolic keys except those beginning with "x-" are reserved.
Expand Down Expand Up @@ -650,8 +652,7 @@ def encode_annotations(value):
return fields


def encode_application_properties(value):
# type: (Optional[Dict[str, Any]]) -> Dict[str, Any]
def encode_application_properties(value: Optional[Dict[Union[str, bytes], Any]]) -> Dict[Union[str, bytes], Any]:
"""The application-properties section is a part of the bare message used for structured application data.
<type name="application-properties" class="restricted" source="map" provides="section">
Expand All @@ -668,7 +669,7 @@ def encode_application_properties(value):
"""
if not value:
return {TYPE: AMQPTypes.null, VALUE: None}
fields = {TYPE: AMQPTypes.map, VALUE: cast(List, [])}
fields: Dict[Union[str, bytes], Any] = {TYPE: AMQPTypes.map, VALUE: cast(List, [])}
for key, data in value.items():
cast(List, fields[VALUE]).append(({TYPE: AMQPTypes.string, VALUE: key}, data))
return fields
Expand Down Expand Up @@ -876,11 +877,11 @@ def describe_performative(performative):
body.append(
{
TYPE: AMQPTypes.array,
VALUE: [_FIELD_DEFINITIONS[field.type](v) for v in value],
VALUE: [_FIELD_DEFINITIONS[field.type](v) for v in value], # type: ignore
}
)
else:
body.append(_FIELD_DEFINITIONS[field.type](value))
body.append(_FIELD_DEFINITIONS[field.type](value)) # type: ignore
elif isinstance(field.type, ObjDefinition):
body.append(describe_performative(value))
else:
Expand Down Expand Up @@ -1033,7 +1034,8 @@ def encode_frame(frame, frame_type=_FRAME_TYPE):
frame_data = bytearray()
encode_value(frame_data, frame_description)
if isinstance(frame, performatives.TransferFrame):
frame_data += frame.payload
# casting from Optional[Buffer] since payload will not be None at this point
frame_data += cast(Buffer, frame.payload)

size = len(frame_data) + 8
header = size.to_bytes(4, "big") + _FRAME_OFFSET + frame_type
Expand Down

0 comments on commit f8f223d

Please sign in to comment.