From 7b053fd6d6a840b266c3b192ee9e288900a884f7 Mon Sep 17 00:00:00 2001 From: bharanidharan14 Date: Tue, 31 May 2022 14:07:58 +0530 Subject: [PATCH 01/13] Add Azure Service Bus Queue Operator's - Create Queue Operator - Send message queue Operator - Receive message queue Operator - Delete queue Operator - example DAG - Added hooks and connection type in - provider yaml file - Added unit Test case --- .../example_dags/example_service_bus_queue.py | 105 ++++++++ .../microsoft/azure/hooks/asb_admin_client.py | 90 +++++++ .../microsoft/azure/hooks/asb_message.py | 122 +++++++++ .../microsoft/azure/hooks/base_asb.py | 73 ++++++ .../operators/azure_service_bus_queue.py | 184 ++++++++++++++ .../providers/microsoft/azure/provider.yaml | 17 ++ .../connections/azure_service_bus.rst | 64 +++++ docs/integration-logos/azure/Service-Bus.svg | 1 + setup.py | 1 + .../azure/hooks/test_asb_admin_client.py | 104 ++++++++ .../microsoft/azure/hooks/test_asb_message.py | 150 ++++++++++++ .../operators/test_azure_service_queue.py | 231 ++++++++++++++++++ 12 files changed, 1142 insertions(+) create mode 100644 airflow/providers/microsoft/azure/example_dags/example_service_bus_queue.py create mode 100644 airflow/providers/microsoft/azure/hooks/asb_admin_client.py create mode 100644 airflow/providers/microsoft/azure/hooks/asb_message.py create mode 100644 airflow/providers/microsoft/azure/hooks/base_asb.py create mode 100644 airflow/providers/microsoft/azure/operators/azure_service_bus_queue.py create mode 100644 docs/apache-airflow-providers-microsoft-azure/connections/azure_service_bus.rst create mode 100644 docs/integration-logos/azure/Service-Bus.svg create mode 100644 tests/providers/microsoft/azure/hooks/test_asb_admin_client.py create mode 100644 tests/providers/microsoft/azure/hooks/test_asb_message.py create mode 100644 tests/providers/microsoft/azure/operators/test_azure_service_queue.py diff --git a/airflow/providers/microsoft/azure/example_dags/example_service_bus_queue.py b/airflow/providers/microsoft/azure/example_dags/example_service_bus_queue.py new file mode 100644 index 0000000000000..da689b82e1628 --- /dev/null +++ b/airflow/providers/microsoft/azure/example_dags/example_service_bus_queue.py @@ -0,0 +1,105 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import os +from datetime import datetime, timedelta + +from airflow import DAG +from airflow.providers.microsoft.azure.operators.azure_service_bus_queue import ( + AzureServiceBusCreateQueueOperator, + AzureServiceBusDeleteQueueOperator, + AzureServiceBusReceiveMessageOperator, + AzureServiceBusSendMessageOperator, +) + +EXECUTION_TIMEOUT = int(os.getenv("EXECUTION_TIMEOUT", 6)) + +default_args = { + "execution_timeout": timedelta(hours=EXECUTION_TIMEOUT), + "azure_service_bus_conn_id": "azure_service_bus_default", +} + +CLIENT_ID = os.getenv("CLIENT_ID", "") +QUEUE_NAME = "sb_mgmt_queue_test" +MESSAGE = "Test Message" +MESSAGE_LIST = [MESSAGE + " " + str(n) for n in range(0, 10)] + +with DAG( + dag_id="example_azure_service_bus_queue", + start_date=datetime(2021, 8, 13), + schedule_interval=None, + catchup=False, + default_args=default_args, + tags=["example", "Azure service bus"], +) as dag: + # [START howto_operator_create_service_bus_queue] + create_service_bus_queue = AzureServiceBusCreateQueueOperator( + task_id="create_service_bus_queue", + queue_name=QUEUE_NAME, + ) + # [END howto_operator_create_service_bus_queue] + + # [START howto_operator_send_message_to_service_bus_queue] + send_message_to_service_bus_queue = AzureServiceBusSendMessageOperator( + task_id="send_message_to_service_bus_queue", + message=MESSAGE, + queue_name=QUEUE_NAME, + batch=False, + ) + # [END howto_operator_send_message_to_service_bus_queue] + + # [START howto_operator_send_list_message_to_service_bus_queue] + send_list_message_to_service_bus_queue = AzureServiceBusSendMessageOperator( + task_id="send_list_message_to_service_bus_queue", + message=MESSAGE_LIST, + queue_name=QUEUE_NAME, + batch=False, + ) + # [END howto_operator_send_list_message_to_service_bus_queue] + + # [START howto_operator_send_batch_message_to_service_bus_queue] + send_batch_message_to_service_bus_queue = AzureServiceBusSendMessageOperator( + task_id="send_batch_message_to_service_bus_queue", + message=MESSAGE_LIST, + queue_name=QUEUE_NAME, + batch=True, + ) + # [END howto_operator_send_batch_message_to_service_bus_queue] + + # [START howto_operator_receive_message_service_bus_queue] + receive_message_service_bus_queue = AzureServiceBusReceiveMessageOperator( + task_id="receive_message_service_bus_queue", + queue_name=QUEUE_NAME, + max_message_count=20, + max_wait_time=5, + ) + # [END howto_operator_receive_message_service_bus_queue] + + # [START howto_operator_delete_service_bus_queue] + delete_service_bus_queue = AzureServiceBusDeleteQueueOperator( + task_id="delete_service_bus_queue", queue_name=QUEUE_NAME, trigger_rule="all_done" + ) + # [END howto_operator_delete_service_bus_queue] + + ( + create_service_bus_queue + >> send_message_to_service_bus_queue + >> send_list_message_to_service_bus_queue + >> send_batch_message_to_service_bus_queue + >> receive_message_service_bus_queue + >> delete_service_bus_queue + ) diff --git a/airflow/providers/microsoft/azure/hooks/asb_admin_client.py b/airflow/providers/microsoft/azure/hooks/asb_admin_client.py new file mode 100644 index 0000000000000..f56b596800996 --- /dev/null +++ b/airflow/providers/microsoft/azure/hooks/asb_admin_client.py @@ -0,0 +1,90 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from azure.servicebus.management import QueueProperties, ServiceBusAdministrationClient + +from airflow.exceptions import AirflowBadRequest, AirflowException +from airflow.providers.microsoft.azure.hooks.base_asb import BaseAzureServiceBusHook + + +class AzureServiceBusAdminClientHook(BaseAzureServiceBusHook): + """ + Interacts with Azure ServiceBus management client + and Use this client to create, update, list, and delete resources of a ServiceBus namespace. + it uses the same azure service bus client connection inherits from the base class + """ + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + def get_conn(self) -> ServiceBusAdministrationClient: + """Create and returns ServiceBusAdministration by using the connection string in connection details""" + conn = self.get_connection(self.conn_id) + extras = conn.extra_dejson + + self.connection_string = str( + extras.get('connection_string') or extras.get('extra__azure_service_bus__connection_string') + ) + return ServiceBusAdministrationClient.from_connection_string(self.connection_string) + + def create_queue( + self, + queue_name: str, + max_delivery_count: int = 10, + dead_lettering_on_message_expiration: bool = True, + enable_batched_operations: bool = True, + ) -> QueueProperties: + """ + Create Queue by connecting to service Bus Admin client return the QueueProperties + + :param queue_name: The name of the queue or a QueueProperties with name. + :param max_delivery_count: The maximum delivery count. A message is automatically + dead lettered after this number of deliveries. Default value is 10.. + :param dead_lettering_on_message_expiration: A value that indicates whether this subscription has + dead letter support when a message expires. + :param enable_batched_operations: Value that indicates whether server-side batched + operations are enabled. + """ + if queue_name is None: + raise AirflowBadRequest("Queue name cannot be None.") + + try: + with self.get_conn() as service_mgmt_conn: + queue = service_mgmt_conn.create_queue( + queue_name, + max_delivery_count=max_delivery_count, + dead_lettering_on_message_expiration=dead_lettering_on_message_expiration, + enable_batched_operations=enable_batched_operations, + ) + return queue + except Exception as e: + raise AirflowException(e) + + def delete_queue(self, queue_name: str) -> None: + """ + Delete the queue by queue_name in service bus namespace + + :param queue_name: The name of the queue or a QueueProperties with name. + """ + if queue_name is None: + raise AirflowBadRequest("Queue name cannot be None.") + + try: + with self.get_conn() as service_mgmt_conn: + service_mgmt_conn.delete_queue(queue_name) + except Exception as e: + raise AirflowException(e) diff --git a/airflow/providers/microsoft/azure/hooks/asb_message.py b/airflow/providers/microsoft/azure/hooks/asb_message.py new file mode 100644 index 0000000000000..b6f8e8931944c --- /dev/null +++ b/airflow/providers/microsoft/azure/hooks/asb_message.py @@ -0,0 +1,122 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import List, Optional, Union + +from azure.servicebus import ServiceBusClient, ServiceBusMessage, ServiceBusSender +from azure.servicebus.exceptions import MessageSizeExceededError + +from airflow.exceptions import AirflowBadRequest, AirflowException +from airflow.providers.microsoft.azure.hooks.base_asb import BaseAzureServiceBusHook + + +class ServiceBusMessageHook(BaseAzureServiceBusHook): + """ + Interacts with ServiceBusClient and acts as a high level interface + for getting ServiceBusSender and ServiceBusReceiver. + """ + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + def get_conn(self) -> ServiceBusClient: + """Create and returns ServiceBusClient by using the connection string in connection details""" + conn = self.get_connection(self.conn_id) + extras = conn.extra_dejson + + self.connection_string = str( + extras.get('connection_string') or extras.get('extra__azure_service_bus__connection_string') + ) + + return ServiceBusClient.from_connection_string(conn_str=self.connection_string, logging_enable=True) + + def send_message( + self, queue_name: str, messages: Union[str, List[str]], batch_message_flag: bool = False + ): + """ + By using ServiceBusClient Send message(s) to a Service Bus Queue. By using + batch_message_flag it enables and send message as batch message + + :param queue_name: The name of the queue or a QueueProperties with name. + :param messages: Message which needs to be sent to the queue. It can be string or list of string. + :param batch_message_flag: bool flag, can be set to True if message needs to be sent as batch message. + """ + if queue_name is None: + raise AirflowBadRequest("Queue name cannot be None.") + if not messages: + raise AirflowException("Message is empty.") + service_bus_client = self.get_conn() + try: + with service_bus_client: + sender = service_bus_client.get_queue_sender(queue_name=queue_name) + with sender: + if isinstance(messages, str): + if not batch_message_flag: + msg = ServiceBusMessage(messages) + sender.send_messages(msg) + else: + self.send_batch_message(sender, [messages]) + else: + if not batch_message_flag: + self.send_list_messages(sender, messages) + else: + self.send_batch_message(sender, messages) + except Exception as e: + raise AirflowException(e) + + @staticmethod + def send_list_messages(sender: ServiceBusSender, messages: List[str]): + list_messages = [ServiceBusMessage(message) for message in messages] + sender.send_messages(list_messages) # type: ignore[arg-type] + + @staticmethod + def send_batch_message(sender: ServiceBusSender, messages: List[str]): + batch_message = sender.create_message_batch() + for message in messages: + try: + batch_message.add_message(ServiceBusMessage(message)) + except MessageSizeExceededError as e: + # ServiceBusMessageBatch object reaches max_size. + # New ServiceBusMessageBatch object can be created here to send more data. + raise AirflowException(e) + sender.send_messages(batch_message) + + def receive_message( + self, queue_name, max_message_count: Optional[int] = 1, max_wait_time: Optional[float] = None + ): + """ + Receive a batch of messages at once in a specified Queue name + + :param queue_name: The name of the queue name or a QueueProperties with name. + :param max_message_count: Maximum number of messages in the batch. + :param max_wait_time: Maximum time to wait in seconds for the first message to arrive. + """ + if queue_name is None: + raise AirflowBadRequest("Queue name cannot be None.") + + service_bus_client = self.get_conn() + try: + with service_bus_client: + receiver = service_bus_client.get_queue_receiver(queue_name=queue_name) + with receiver: + received_msgs = receiver.receive_messages( + max_message_count=max_message_count, max_wait_time=max_wait_time + ) + for msg in received_msgs: + self.log.info(msg) + receiver.complete_message(msg) + except Exception as e: + raise AirflowException(e) diff --git a/airflow/providers/microsoft/azure/hooks/base_asb.py b/airflow/providers/microsoft/azure/hooks/base_asb.py new file mode 100644 index 0000000000000..84a0c20b6d629 --- /dev/null +++ b/airflow/providers/microsoft/azure/hooks/base_asb.py @@ -0,0 +1,73 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Any, Dict, Optional + +from airflow.hooks.base import BaseHook + + +class BaseAzureServiceBusHook(BaseHook): + """ + BaseAzureServiceBusHook class to session creation and connection creation. Client ID and + Secrete IDs are optional. + + :param azure_service_bus_conn_id: Reference to the + :ref:`Azure Service Bus connection`. + """ + + conn_name_attr = 'azure_service_bus_conn_id' + default_conn_name = 'azure_service_bus_default' + conn_type = 'azure_service_bus' + hook_name = 'Azure ServiceBus' + + @staticmethod + def get_connection_form_widgets() -> Dict[str, Any]: + """Returns connection widgets to add to connection form""" + from flask_appbuilder.fieldwidgets import BS3TextFieldWidget + from flask_babel import lazy_gettext + from wtforms import StringField + + return { + "extra__azure_service_bus__connection_string": StringField( + lazy_gettext('Service Bus Connection String'), widget=BS3TextFieldWidget() + ), + } + + @staticmethod + def get_ui_field_behaviour() -> Dict[str, Any]: + """Returns custom field behaviour""" + return { + "hidden_fields": ['schema', 'port', 'host', 'extra'], + "relabeling": { + 'login': 'Client ID', + 'password': 'Secret', + }, + "placeholders": { + 'login': 'Client ID (Optional)', + 'password': 'Client Secret (Optional)', + 'extra__azure_service_bus__connection_string': 'Service Bus Connection String', + }, + } + + def __init__(self, azure_service_bus_conn_id: str = default_conn_name) -> None: + super().__init__() + self.conn_id = azure_service_bus_conn_id + self._conn = None + self.connection_string: Optional[str] = None + + def get_conn(self): + return None diff --git a/airflow/providers/microsoft/azure/operators/azure_service_bus_queue.py b/airflow/providers/microsoft/azure/operators/azure_service_bus_queue.py new file mode 100644 index 0000000000000..31798600a0f54 --- /dev/null +++ b/airflow/providers/microsoft/azure/operators/azure_service_bus_queue.py @@ -0,0 +1,184 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import TYPE_CHECKING, List, Optional, Sequence, Union + +from airflow.models import BaseOperator +from airflow.providers.microsoft.azure.hooks.asb_admin_client import AzureServiceBusAdminClientHook +from airflow.providers.microsoft.azure.hooks.asb_message import ServiceBusMessageHook + +if TYPE_CHECKING: + from airflow.utils.context import Context + + +class AzureServiceBusCreateQueueOperator(BaseOperator): + """ + Creates a Azure ServiceBus queue under a ServiceBus Namespace by using ServiceBusAdministrationClient + + :param queue_name: The name of the queue. should be unique. + :param azure_service_bus_conn_id: Reference to the + :ref:`Azure Service Bus connection`. + """ + + template_fields: Sequence[str] = ("queue_name",) + ui_color = "#e4f0e8" + + def __init__( + self, + *, + queue_name: str, + max_delivery_count: int = 10, + dead_lettering_on_message_expiration: bool = True, + enable_batched_operations: bool = True, + azure_service_bus_conn_id: str = 'azure_service_bus_default', + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.queue_name = queue_name + self.max_delivery_count = max_delivery_count + self.dead_lettering_on_message_expiration = dead_lettering_on_message_expiration + self.enable_batched_operations = enable_batched_operations + self.azure_service_bus_conn_id = azure_service_bus_conn_id + + def execute(self, context: "Context") -> None: + """Creates Queue in Service Bus namespace, by connecting to Service Bus Admin client""" + # Create the hook testing + hook = AzureServiceBusAdminClientHook(azure_service_bus_conn_id=self.azure_service_bus_conn_id) + + # create queue with name + queue = hook.create_queue( + self.queue_name, + self.max_delivery_count, + self.dead_lettering_on_message_expiration, + self.enable_batched_operations, + ) + self.log.info("Created Queue %s", queue.name) + self.log.info(queue) + + +class AzureServiceBusSendMessageOperator(BaseOperator): + """ + Send Message or batch message to the service bus queue + + :param message: Message which needs to be sent to the queue. It can be string or list of string. + :param batch: Its boolean flag by default it is set to False, if the message needs to be sent + as batch message it can be set to True. + :param azure_service_bus_conn_id: Reference to the + :ref: `Azure Service Bus connection`. + """ + + template_fields: Sequence[str] = ("queue_name", "message") + ui_color = "#e4f0e8" + + def __init__( + self, + *, + queue_name: str, + message: Union[str, List[str]], + batch: bool = False, + azure_service_bus_conn_id: str = 'azure_service_bus_default', + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.queue_name = queue_name + self.batch = batch + self.message = message + self.azure_service_bus_conn_id = azure_service_bus_conn_id + + def execute(self, context: "Context") -> None: + """ + Sends Message to the specific queue in Service Bus namespace, by + connecting to Service Bus client + """ + # Create the hook + hook = ServiceBusMessageHook(azure_service_bus_conn_id=self.azure_service_bus_conn_id) + + # send message + hook.send_message(self.queue_name, self.message, self.batch) + + +class AzureServiceBusReceiveMessageOperator(BaseOperator): + """ + Receive a batch of messages at once in a specified Queue name + + :param queue_name: The name of the queue in Service Bus namespace. + :param azure_service_bus_conn_id: Reference to the + :ref: `Azure Service Bus connection `. + """ + + template_fields: Sequence[str] = ("queue_name",) + ui_color = "#e4f0e8" + + def __init__( + self, + *, + queue_name: str, + azure_service_bus_conn_id: str = 'azure_service_bus_default', + max_message_count: Optional[int] = 10, + max_wait_time: Optional[float] = 5, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.queue_name = queue_name + self.azure_service_bus_conn_id = azure_service_bus_conn_id + self.max_message_count = max_message_count + self.max_wait_time = max_wait_time + + def execute(self, context: "Context") -> None: + """ + Receive Message in specific queue in Service Bus namespace, + by connecting to Service Bus client + """ + # Create the hook + hook = ServiceBusMessageHook(azure_service_bus_conn_id=self.azure_service_bus_conn_id) + + # Receive message + hook.receive_message( + self.queue_name, max_message_count=self.max_message_count, max_wait_time=self.max_wait_time + ) + + +class AzureServiceBusDeleteQueueOperator(BaseOperator): + """ + Deletes the Queue in the Azure ServiceBus namespace + + :param queue_name: The name of the queue in Service Bus namespace. + :param azure_service_bus_conn_id: Reference to the + :ref: `Azure Service Bus connection `. + """ + + template_fields: Sequence[str] = ("queue_name",) + ui_color = "#e4f0e8" + + def __init__( + self, + *, + queue_name: str, + azure_service_bus_conn_id: str = 'azure_service_bus_default', + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.queue_name = queue_name + self.azure_service_bus_conn_id = azure_service_bus_conn_id + + def execute(self, context: "Context") -> None: + """Delete Queue in Service Bus namespace, by connecting to Service Bus Admin client""" + # Create the hook + hook = AzureServiceBusAdminClientHook(azure_service_bus_conn_id=self.azure_service_bus_conn_id) + + # delete queue with name + hook.delete_queue(self.queue_name) diff --git a/airflow/providers/microsoft/azure/provider.yaml b/airflow/providers/microsoft/azure/provider.yaml index 10cab462b4c53..e4c024897ceea 100644 --- a/airflow/providers/microsoft/azure/provider.yaml +++ b/airflow/providers/microsoft/azure/provider.yaml @@ -88,6 +88,10 @@ integrations: external-doc-url: https://azure.microsoft.com/ logo: /integration-logos/azure/Microsoft-Azure.png tags: [azure] + - integration-name: Microsoft Azure Service Bus + external-doc-url: https://azure.microsoft.com/en-us/services/service-bus/ + logo: /integration-logos/azure/Service-Bus.svg + tags: [azure] operators: - integration-name: Microsoft Azure Data Lake Storage @@ -116,6 +120,9 @@ operators: - integration-name: Microsoft Azure Data Factory python-modules: - airflow.providers.microsoft.azure.operators.data_factory + - integration-name: Microsoft Azure Service Bus + python-modules: + - airflow.providers.microsoft.azure.operators.azure_service_bus_queue sensors: - integration-name: Microsoft Azure Cosmos DB @@ -167,6 +174,11 @@ hooks: python-modules: - airflow.providers.microsoft.azure.hooks.data_factory - airflow.providers.microsoft.azure.hooks.azure_data_factory + - integration-name: Microsoft Azure Service Bus + python-modules: + - airflow.providers.microsoft.azure.hooks.base_asb + - airflow.providers.microsoft.azure.hooks.asb_admin_client + - airflow.providers.microsoft.azure.hooks.asb_message transfers: - source-integration-name: Local @@ -203,6 +215,9 @@ hook-class-names: # deprecated - to be removed after providers add dependency o - airflow.providers.microsoft.azure.hooks.wasb.WasbHook - airflow.providers.microsoft.azure.hooks.data_factory.AzureDataFactoryHook - airflow.providers.microsoft.azure.hooks.container_registry.AzureContainerRegistryHook + - airflow.providers.microsoft.azure.hooks.base_asb.BaseAzureServiceBusHook + - airflow.providers.microsoft.azure.hooks.asb_admin_client.AzureServiceBusAdminClientHook + - airflow.providers.microsoft.azure.hooks.asb_message.ServiceBusMessageHook connection-types: - hook-class-name: airflow.providers.microsoft.azure.hooks.base_azure.AzureBaseHook @@ -229,6 +244,8 @@ connection-types: - hook-class-name: >- airflow.providers.microsoft.azure.hooks.container_registry.AzureContainerRegistryHook connection-type: azure_container_registry + - hook-class-name: airflow.providers.microsoft.azure.hooks.base_asb.BaseAzureServiceBusHook + connection-type: azure_service_bus secrets-backends: - airflow.providers.microsoft.azure.secrets.key_vault.AzureKeyVaultBackend diff --git a/docs/apache-airflow-providers-microsoft-azure/connections/azure_service_bus.rst b/docs/apache-airflow-providers-microsoft-azure/connections/azure_service_bus.rst new file mode 100644 index 0000000000000..638c993656d57 --- /dev/null +++ b/docs/apache-airflow-providers-microsoft-azure/connections/azure_service_bus.rst @@ -0,0 +1,64 @@ +.. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + + +.. _howto/connection:azure_service_bus: + +Microsoft Azure Service Bus +======================================= + +The Microsoft Azure Service Bus connection type enables the Azure Service Bus Integrations. + +Authenticating to Azure Service Bus +------------------------------------ + +There are two ways to authenticate and authorize access to Azure Service Bus resources: +Azure Active Directory (Azure AD) and Shared Access Signatures (SAS). + +1. Use `Azure Active Directory + `_ + i.e. add specific credentials (client_id, secret, tenant) and subscription id to the Airflow connection. +2. Use `Azure Shared access signature + `_ +3. Fallback on `DefaultAzureCredential + `_. + This includes a mechanism to try different options to authenticate: Managed System Identity, environment variables, authentication through Azure CLI... + +Default Connection IDs +---------------------- + +All hooks and operators related to Microsoft Azure Service Bus use ``azure_service_bus_default`` by default. + +Configuring the Connection +-------------------------- + +Client ID (optional) + Specify the ``client_id`` used for the initial connection. + This is needed for *token credentials* authentication mechanism. + It can be left out to fall back on ``DefaultAzureCredential``. + +Secret (optional) + Specify the ``secret`` used for the initial connection. + This is needed for *token credentials* authentication mechanism. + It can be left out to fall back on ``DefaultAzureCredential``. + +Connection ID + Specify the Azure Service bus connection string ID used for the initial connection. + `Get connection string + `_ + Use the key ``extra__azure_service_bus__connection_string`` to pass in the Connection ID . diff --git a/docs/integration-logos/azure/Service-Bus.svg b/docs/integration-logos/azure/Service-Bus.svg new file mode 100644 index 0000000000000..1604e04232630 --- /dev/null +++ b/docs/integration-logos/azure/Service-Bus.svg @@ -0,0 +1 @@ + diff --git a/setup.py b/setup.py index 3094dca782edd..198d812bb92a3 100644 --- a/setup.py +++ b/setup.py @@ -231,6 +231,7 @@ def write_version(filename: str = os.path.join(*[my_dir, "airflow", "git_version 'azure-storage-blob>=12.7.0,<12.9.0', 'azure-storage-common>=2.1.0', 'azure-storage-file>=2.1.0', + 'azure-servicebus>=7.6.1', ] cassandra = [ 'cassandra-driver>=3.13.0', diff --git a/tests/providers/microsoft/azure/hooks/test_asb_admin_client.py b/tests/providers/microsoft/azure/hooks/test_asb_admin_client.py new file mode 100644 index 0000000000000..6521ed64e7699 --- /dev/null +++ b/tests/providers/microsoft/azure/hooks/test_asb_admin_client.py @@ -0,0 +1,104 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import json +from unittest import mock + +import pytest +from azure.servicebus.management import ServiceBusAdministrationClient + +from airflow import AirflowException +from airflow.models import Connection +from airflow.providers.microsoft.azure.hooks.asb_admin_client import AzureServiceBusAdminClientHook + + +class TestAzureServiceBusAdminClientHook: + def setup_class(self) -> None: + self.queue_name: str = "test_queue" + self.conn_id: str = 'azure_service_bus_default' + self.connection_string = ( + "Endpoint=sb://test-service-bus-provider.servicebus.windows.net/;" + "SharedAccessKeyName=Test;SharedAccessKey=1234566acbc" + ) + self.client_id = "test_client_id" + self.secret_key = "test_client_secret" + self.mock_conn = Connection( + conn_id='azure_service_bus_default', + conn_type='azure_service_bus', + login=self.client_id, + password=self.secret_key, + extra=json.dumps({"connection_string": self.connection_string}), + ) + + @mock.patch( + "airflow.providers.microsoft.azure.hooks.asb_admin_client." + "AzureServiceBusAdminClientHook.get_connection" + ) + def test_get_conn(self, mock_connection): + mock_connection.return_value = self.mock_conn + hook = AzureServiceBusAdminClientHook(azure_service_bus_conn_id=self.conn_id) + assert isinstance(hook.get_conn(), ServiceBusAdministrationClient) + + @mock.patch('azure.servicebus.management.QueueProperties') + @mock.patch( + 'airflow.providers.microsoft.azure.hooks.asb_admin_client.AzureServiceBusAdminClientHook.get_conn' + ) + def test_create_queue(self, mock_sb_admin_client, mock_queue_properties): + """ + Test `create_queue` hook function with mocking connection, queue properties value and + the azure service bus `create_queue` function + """ + mock_queue_properties.name = self.queue_name + mock_sb_admin_client.return_value.__enter__.return_value.create_queue.return_value = ( + mock_queue_properties + ) + hook = AzureServiceBusAdminClientHook(azure_service_bus_conn_id=self.conn_id) + response = hook.create_queue(self.queue_name) + assert response == mock_queue_properties + + @mock.patch('airflow.providers.microsoft.azure.hooks.asb_admin_client.ServiceBusAdministrationClient') + def test_create_queue_exception(self, mock_sb_admin_client): + """ + Test `create_queue` functionality to raise AirflowException + by passing queue name as None and pytest raise Airflow Exception + """ + hook = AzureServiceBusAdminClientHook(azure_service_bus_conn_id=self.conn_id) + with pytest.raises(AirflowException): + hook.create_queue(None) + + @mock.patch( + 'airflow.providers.microsoft.azure.hooks.asb_admin_client.AzureServiceBusAdminClientHook.get_conn' + ) + def test_delete_queue(self, mock_sb_admin_client): + """ + Test Delete queue functionality by passing queue name, assert the function with values, + mock the azure service bus function `delete_queue` + """ + hook = AzureServiceBusAdminClientHook(azure_service_bus_conn_id=self.conn_id) + hook.delete_queue(self.queue_name) + expected_calls = [mock.call().__enter__().delete_queue(self.queue_name)] + mock_sb_admin_client.assert_has_calls(expected_calls) + + @mock.patch('airflow.providers.microsoft.azure.hooks.asb_admin_client.ServiceBusAdministrationClient') + def test_delete_queue_exception(self, mock_sb_admin_client): + """ + Test `delete_queue` functionality to raise AirflowException, + by passing queue name as None and pytest raise Airflow Exception + """ + hook = AzureServiceBusAdminClientHook(azure_service_bus_conn_id=self.conn_id) + with pytest.raises(AirflowException): + hook.delete_queue(None) diff --git a/tests/providers/microsoft/azure/hooks/test_asb_message.py b/tests/providers/microsoft/azure/hooks/test_asb_message.py new file mode 100644 index 0000000000000..24b86361265a8 --- /dev/null +++ b/tests/providers/microsoft/azure/hooks/test_asb_message.py @@ -0,0 +1,150 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import json +from unittest import mock + +import pytest +from azure.servicebus import ServiceBusClient, ServiceBusMessage, ServiceBusMessageBatch + +from airflow import AirflowException +from airflow.models import Connection +from airflow.providers.microsoft.azure.hooks.asb_message import ServiceBusMessageHook + + +class TestServiceBusMessageHook: + def setup_class(self) -> None: + self.queue_name: str = "test_queue" + self.conn_id: str = 'azure_service_bus_default' + self.connection_string = ( + "Endpoint=sb://test-service-bus-provider.servicebus.windows.net/;" + "SharedAccessKeyName=Test;SharedAccessKey=1234566acbc" + ) + self.client_id = "test_client_id" + self.secret_key = "test_client_secret" + self.conn = Connection( + conn_id='azure_service_bus_default', + conn_type='azure_service_bus', + login=self.client_id, + password=self.secret_key, + extra=json.dumps({'connection_string': self.connection_string}), + ) + + @mock.patch("airflow.providers.microsoft.azure.hooks.asb_message.ServiceBusMessageHook.get_connection") + def test_get_service_bus_message_conn(self, mock_connection): + mock_connection.return_value = self.conn + hook = ServiceBusMessageHook(azure_service_bus_conn_id=self.conn_id) + assert isinstance(hook.get_conn(), ServiceBusClient) + + @mock.patch("airflow.providers.microsoft.azure.hooks.asb_message.ServiceBusMessageHook.get_connection") + def test_get_conn_value_error(self, mock_connection): + mock_connection.return_value = Connection( + conn_id='azure_service_bus_default', + conn_type='azure_service_bus', + login=self.client_id, + password=self.secret_key, + extra=json.dumps({"connection_string": "test connection"}), + ) + hook = ServiceBusMessageHook(azure_service_bus_conn_id=self.conn_id) + with pytest.raises(ValueError): + hook.get_conn() + + @pytest.mark.parametrize( + "mock_message, mock_batch_flag", + [ + ("Test message", True), + ("Test message", False), + (["Test message 1", "Test message 2"], True), + (["Test message 1", "Test message 2"], False), + ], + ) + @mock.patch( + 'airflow.providers.microsoft.azure.hooks.asb_message.ServiceBusMessageHook.send_list_messages' + ) + @mock.patch( + 'airflow.providers.microsoft.azure.hooks.asb_message.ServiceBusMessageHook.send_batch_message' + ) + @mock.patch('airflow.providers.microsoft.azure.hooks.asb_message.ServiceBusMessageHook.get_conn') + def test_send_message( + self, mock_sb_client, mock_batch_message, mock_list_message, mock_message, mock_batch_flag + ): + """ + Test `send_message` hook function with batch flag and message passed as mocked params, + which can be string or list of string, mock the azure service bus `send_messages` function + """ + hook = ServiceBusMessageHook(azure_service_bus_conn_id="azure_service_bus_default") + hook.send_message( + queue_name=self.queue_name, messages=mock_message, batch_message_flag=mock_batch_flag + ) + if isinstance(mock_message, list): + if mock_batch_flag: + message = ServiceBusMessageBatch(mock_message) + else: + message = [ServiceBusMessage(msg) for msg in mock_message] + elif isinstance(mock_message, str): + if mock_batch_flag: + message = ServiceBusMessageBatch(mock_message) + else: + message = ServiceBusMessage(mock_message) + + expected_calls = [ + mock.call() + .__enter__() + .get_queue_sender(self.queue_name) + .__enter__() + .send_messages(message) + .__exit__() + ] + mock_sb_client.assert_has_calls(expected_calls, any_order=False) + + @mock.patch('airflow.providers.microsoft.azure.hooks.asb_message.ServiceBusMessageHook.get_conn') + def test_send_message_exception(self, mock_sb_client): + """ + Test `send_message` functionality to raise AirflowException in Azure ServiceBusMessageHook + by passing queue name as None + """ + hook = ServiceBusMessageHook(azure_service_bus_conn_id=self.conn_id) + with pytest.raises(AirflowException): + hook.send_message(queue_name=None, messages="", batch_message_flag=False) + + @mock.patch('airflow.providers.microsoft.azure.hooks.asb_message.ServiceBusMessageHook.get_conn') + def test_receive_message(self, mock_sb_client): + """ + Test `receive_message` hook function and assert the function with mock value, + mock the azure service bus `receive_messages` function + """ + hook = ServiceBusMessageHook(azure_service_bus_conn_id=self.conn_id) + hook.receive_message(self.queue_name) + expected_calls = [ + mock.call() + .__enter__() + .get_queue_receiver(self.queue_name) + .__enter__() + .receive_messages(max_message_count=10, max_wait_time=5) + .__iter__() + ] + mock_sb_client.assert_has_calls(expected_calls) + + @mock.patch('airflow.providers.microsoft.azure.hooks.asb_message.ServiceBusMessageHook.get_conn') + def test_receive_message_exception(self, mock_sb_client): + """ + Test `receive_message` functionality to raise AirflowException in Azure ServiceBusMessageHook + by passing queue name as None + """ + hook = ServiceBusMessageHook(azure_service_bus_conn_id=self.conn_id) + with pytest.raises(AirflowException): + hook.receive_message(None) diff --git a/tests/providers/microsoft/azure/operators/test_azure_service_queue.py b/tests/providers/microsoft/azure/operators/test_azure_service_queue.py new file mode 100644 index 0000000000000..4294fd42407d0 --- /dev/null +++ b/tests/providers/microsoft/azure/operators/test_azure_service_queue.py @@ -0,0 +1,231 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import datetime +import unittest +from unittest import mock + +from azure.servicebus import ServiceBusMessage + +from airflow.models.dag import DAG +from airflow.providers.microsoft.azure.operators.azure_service_bus_queue import ( + AzureServiceBusCreateQueueOperator, + AzureServiceBusDeleteQueueOperator, + AzureServiceBusReceiveMessageOperator, + AzureServiceBusSendMessageOperator, +) + +QUEUE_NAME = "test_queue" +OWNER_NAME = "airflow" +DAG_ID = "test_azure_service_bus_queue" + + +class TestAzureServiceBusCreateQueueOperator(unittest.TestCase): + def setUp(self): + args = {'owner': OWNER_NAME, 'start_date': datetime.datetime(2017, 1, 1)} + self.dag = DAG(DAG_ID, default_args=args) + + def test_init(self): + """ + Test init by creating AzureServiceBusCreateQueueOperator with task id, queue_name and + asserting with value + """ + asb_create_queue_operator = AzureServiceBusCreateQueueOperator( + task_id="asb_create_queue", + queue_name=QUEUE_NAME, + dag=self.dag, + max_delivery_count=10, + dead_lettering_on_message_expiration=True, + enable_batched_operations=True, + ) + assert asb_create_queue_operator.task_id == "asb_create_queue" + assert asb_create_queue_operator.queue_name == QUEUE_NAME + assert asb_create_queue_operator.max_delivery_count == 10 + assert asb_create_queue_operator.dead_lettering_on_message_expiration is True + assert asb_create_queue_operator.enable_batched_operations is True + + asb_create_queue_operator = AzureServiceBusCreateQueueOperator( + task_id="asb_create_queue_test", + queue_name=QUEUE_NAME, + dag=self.dag, + max_delivery_count=10, + dead_lettering_on_message_expiration=False, + enable_batched_operations=False, + ) + assert asb_create_queue_operator.task_id == "asb_create_queue_test" + assert asb_create_queue_operator.queue_name == QUEUE_NAME + assert asb_create_queue_operator.max_delivery_count == 10 + assert asb_create_queue_operator.dead_lettering_on_message_expiration is False + assert asb_create_queue_operator.enable_batched_operations is False + + @mock.patch( + "airflow.providers.microsoft.azure.hooks.asb_admin_client.AzureServiceBusAdminClientHook.get_conn" + ) + def test_create_queue(self, mock_get_conn): + """ + Test AzureServiceBusCreateQueueOperator passed with the queue name, + mocking the connection details, hook create_queue function + """ + asb_create_queue_operator = AzureServiceBusCreateQueueOperator( + task_id="asb_create_queue_operator", + queue_name=QUEUE_NAME, + max_delivery_count=10, + dead_lettering_on_message_expiration=True, + enable_batched_operations=True, + dag=self.dag, + ) + asb_create_queue_operator.execute(None) + mock_get_conn.return_value.__enter__.return_value.create_queue.assert_called_once_with( + QUEUE_NAME, + max_delivery_count=10, + dead_lettering_on_message_expiration=True, + enable_batched_operations=True, + ) + + +class TestAzureServiceBusDeleteQueueOperator(unittest.TestCase): + def setUp(self): + args = {'owner': OWNER_NAME, 'start_date': datetime.datetime(2017, 1, 1)} + self.dag = DAG(DAG_ID, default_args=args) + + def test_init(self): + """ + Test init by creating AzureServiceBusDeleteQueueOperator with task id, queue_name and + asserting with values + """ + asb_delete_queue_operator = AzureServiceBusDeleteQueueOperator( + task_id="asb_delete_queue", + queue_name=QUEUE_NAME, + dag=self.dag, + ) + assert asb_delete_queue_operator.task_id == "asb_delete_queue" + assert asb_delete_queue_operator.queue_name == QUEUE_NAME + + @mock.patch( + "airflow.providers.microsoft.azure.hooks.asb_admin_client.AzureServiceBusAdminClientHook.get_conn" + ) + def test_delete_queue(self, mock_get_conn): + """Test AzureServiceBusDeleteQueueOperator by mocking queue name, connection and hook delete_queue""" + asb_delete_queue_operator = AzureServiceBusDeleteQueueOperator( + task_id="asb_delete_queue", + queue_name=QUEUE_NAME, + dag=self.dag, + ) + asb_delete_queue_operator.execute(None) + mock_get_conn.return_value.__enter__.return_value.delete_queue.assert_called_once_with(QUEUE_NAME) + + +class TestAzureServiceBusSendMessageOperator(unittest.TestCase): + def setUp(self): + args = {'owner': OWNER_NAME, 'start_date': datetime.datetime(2017, 1, 1)} + self.dag = DAG(DAG_ID, default_args=args) + + def test_init(self): + """ + Test init by creating AzureServiceBusSendMessageOperator with task id, queue_name, message, + batch and asserting with values + """ + msg = "test message" + asb_send_message_queue_operator = AzureServiceBusSendMessageOperator( + task_id="asb_send_message_queue_without_batch", + queue_name=QUEUE_NAME, + message=msg, + batch=False, + dag=self.dag, + ) + assert asb_send_message_queue_operator.task_id == "asb_send_message_queue_without_batch" + assert asb_send_message_queue_operator.queue_name == QUEUE_NAME + assert asb_send_message_queue_operator.message == msg + assert asb_send_message_queue_operator.batch is False + + asb_send_message_queue_operator = AzureServiceBusSendMessageOperator( + task_id="asb_send_message_queue_with_batch", + queue_name=QUEUE_NAME, + message=msg, + batch=True, + dag=self.dag, + ) + assert asb_send_message_queue_operator.task_id == "asb_send_message_queue_with_batch" + assert asb_send_message_queue_operator.queue_name == QUEUE_NAME + assert asb_send_message_queue_operator.message == msg + assert asb_send_message_queue_operator.batch is True + + @mock.patch("airflow.providers.microsoft.azure.hooks.asb_message.ServiceBusMessageHook.get_conn") + def test_send_message_queue(self, mock_get_conn): + """ + Test AzureServiceBusSendMessageOperator with queue name, batch boolean flag, mock + the send_messages of azure service bus function + """ + asb_send_message_queue_operator = AzureServiceBusSendMessageOperator( + task_id="asb_send_message_queue", + queue_name=QUEUE_NAME, + message="Test message", + batch=False, + dag=self.dag, + ) + asb_send_message_queue_operator.execute(None) + expected_calls = [ + mock.call() + .__enter__() + .get_queue_sender(QUEUE_NAME) + .__enter__() + .send_messages(ServiceBusMessage("Test message")) + .__exit__() + ] + mock_get_conn.assert_has_calls(expected_calls, any_order=False) + + +class TestAzureServiceBusReceiveMessageOperator(unittest.TestCase): + def setUp(self): + args = {'owner': OWNER_NAME, 'start_date': datetime.datetime(2017, 1, 1)} + self.dag = DAG(DAG_ID, default_args=args) + + def test_init(self): + """ + Test init by creating AzureServiceBusReceiveMessageOperator with task id, queue_name, message, + batch and asserting with values + """ + + asb_receive_queue_operator = AzureServiceBusReceiveMessageOperator( + task_id="asb_receive_message_queue", + queue_name=QUEUE_NAME, + dag=self.dag, + ) + assert asb_receive_queue_operator.task_id == "asb_receive_message_queue" + assert asb_receive_queue_operator.queue_name == QUEUE_NAME + + @mock.patch("airflow.providers.microsoft.azure.hooks.asb_message.ServiceBusMessageHook.get_conn") + def test_receive_message_queue(self, mock_get_conn): + """ + Test AzureServiceBusReceiveMessageOperator by mock connection, values + and the service bus receive message + """ + asb_receive_queue_operator = AzureServiceBusReceiveMessageOperator( + task_id="asb_receive_message_queue", + queue_name=QUEUE_NAME, + dag=self.dag, + ) + asb_receive_queue_operator.execute(None) + expected_calls = [ + mock.call() + .__enter__() + .get_queue_receiver(QUEUE_NAME) + .__enter__() + .receive_messages(max_message_count=10, max_wait_time=5) + .__iter__() + ] + mock_get_conn.assert_has_calls(expected_calls) From 290917bc0fce96b4d12ac7b27f43264357c2d3d6 Mon Sep 17 00:00:00 2001 From: Pankaj Date: Tue, 31 May 2022 21:07:34 +0530 Subject: [PATCH 02/13] Fix tests --- .../microsoft/azure/hooks/test_asb_message.py | 14 ++++++++------ .../azure/operators/test_azure_service_queue.py | 14 ++++++++------ 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/tests/providers/microsoft/azure/hooks/test_asb_message.py b/tests/providers/microsoft/azure/hooks/test_asb_message.py index 24b86361265a8..bbe58056f9817 100644 --- a/tests/providers/microsoft/azure/hooks/test_asb_message.py +++ b/tests/providers/microsoft/azure/hooks/test_asb_message.py @@ -130,12 +130,14 @@ def test_receive_message(self, mock_sb_client): hook = ServiceBusMessageHook(azure_service_bus_conn_id=self.conn_id) hook.receive_message(self.queue_name) expected_calls = [ - mock.call() - .__enter__() - .get_queue_receiver(self.queue_name) - .__enter__() - .receive_messages(max_message_count=10, max_wait_time=5) - .__iter__() + mock.call(), + mock.call().__enter__(), + mock.call().get_queue_receiver(queue_name='test_queue'), + mock.call().get_queue_receiver().__enter__(), + mock.call().get_queue_receiver().receive_messages(max_message_count=1, max_wait_time=None), + mock.call().get_queue_receiver().receive_messages().__iter__(), + mock.call().get_queue_receiver().__exit__(None, None, None), + mock.call().__exit__(None, None, None), ] mock_sb_client.assert_has_calls(expected_calls) diff --git a/tests/providers/microsoft/azure/operators/test_azure_service_queue.py b/tests/providers/microsoft/azure/operators/test_azure_service_queue.py index 4294fd42407d0..d307234462732 100644 --- a/tests/providers/microsoft/azure/operators/test_azure_service_queue.py +++ b/tests/providers/microsoft/azure/operators/test_azure_service_queue.py @@ -221,11 +221,13 @@ def test_receive_message_queue(self, mock_get_conn): ) asb_receive_queue_operator.execute(None) expected_calls = [ - mock.call() - .__enter__() - .get_queue_receiver(QUEUE_NAME) - .__enter__() - .receive_messages(max_message_count=10, max_wait_time=5) - .__iter__() + mock.call(), + mock.call().__enter__(), + mock.call().get_queue_receiver(queue_name='test_queue'), + mock.call().get_queue_receiver().__enter__(), + mock.call().get_queue_receiver().receive_messages(max_message_count=10, max_wait_time=5), + mock.call().get_queue_receiver().receive_messages().__iter__(), + mock.call().get_queue_receiver().__exit__(None, None, None), + mock.call().__exit__(None, None, None), ] mock_get_conn.assert_has_calls(expected_calls) From 3f313e311954881a2be0c0864ff838267a80af68 Mon Sep 17 00:00:00 2001 From: bharanidharan14 Date: Tue, 31 May 2022 21:27:56 +0530 Subject: [PATCH 03/13] Fix CircleCI failure --- airflow/providers/microsoft/azure/provider.yaml | 2 -- 1 file changed, 2 deletions(-) diff --git a/airflow/providers/microsoft/azure/provider.yaml b/airflow/providers/microsoft/azure/provider.yaml index e4c024897ceea..805fed9ab973c 100644 --- a/airflow/providers/microsoft/azure/provider.yaml +++ b/airflow/providers/microsoft/azure/provider.yaml @@ -216,8 +216,6 @@ hook-class-names: # deprecated - to be removed after providers add dependency o - airflow.providers.microsoft.azure.hooks.data_factory.AzureDataFactoryHook - airflow.providers.microsoft.azure.hooks.container_registry.AzureContainerRegistryHook - airflow.providers.microsoft.azure.hooks.base_asb.BaseAzureServiceBusHook - - airflow.providers.microsoft.azure.hooks.asb_admin_client.AzureServiceBusAdminClientHook - - airflow.providers.microsoft.azure.hooks.asb_message.ServiceBusMessageHook connection-types: - hook-class-name: airflow.providers.microsoft.azure.hooks.base_azure.AzureBaseHook From dddee2f63a91d9091af87d039d3af23c54516206 Mon Sep 17 00:00:00 2001 From: bharanidharan14 Date: Tue, 31 May 2022 23:18:16 +0530 Subject: [PATCH 04/13] Fix test case --- .../microsoft/azure/hooks/test_asb_message.py | 23 +++++++++++-------- .../operators/test_azure_service_queue.py | 17 +++++++------- 2 files changed, 23 insertions(+), 17 deletions(-) diff --git a/tests/providers/microsoft/azure/hooks/test_asb_message.py b/tests/providers/microsoft/azure/hooks/test_asb_message.py index bbe58056f9817..bc1d93c366009 100644 --- a/tests/providers/microsoft/azure/hooks/test_asb_message.py +++ b/tests/providers/microsoft/azure/hooks/test_asb_message.py @@ -121,23 +121,28 @@ def test_send_message_exception(self, mock_sb_client): with pytest.raises(AirflowException): hook.send_message(queue_name=None, messages="", batch_message_flag=False) + @mock.patch('azure.servicebus.ServiceBusMessage') @mock.patch('airflow.providers.microsoft.azure.hooks.asb_message.ServiceBusMessageHook.get_conn') - def test_receive_message(self, mock_sb_client): + def test_receive_message(self, mock_sb_client, mock_service_bus_message): """ Test `receive_message` hook function and assert the function with mock value, mock the azure service bus `receive_messages` function """ hook = ServiceBusMessageHook(azure_service_bus_conn_id=self.conn_id) + mock_sb_client.return_value.get_queue_receiver.return_value.receive_messages.return_value = [ + mock_service_bus_message + ] hook.receive_message(self.queue_name) expected_calls = [ - mock.call(), - mock.call().__enter__(), - mock.call().get_queue_receiver(queue_name='test_queue'), - mock.call().get_queue_receiver().__enter__(), - mock.call().get_queue_receiver().receive_messages(max_message_count=1, max_wait_time=None), - mock.call().get_queue_receiver().receive_messages().__iter__(), - mock.call().get_queue_receiver().__exit__(None, None, None), - mock.call().__exit__(None, None, None), + mock.call() + .__enter__() + .get_queue_receiver(self.queue_name) + .__enter__() + .receive_messages(max_message_count=30, max_wait_time=5) + .get_queue_receiver(self.queue_name) + .__exit__() + .mock_call() + .__exit__ ] mock_sb_client.assert_has_calls(expected_calls) diff --git a/tests/providers/microsoft/azure/operators/test_azure_service_queue.py b/tests/providers/microsoft/azure/operators/test_azure_service_queue.py index d307234462732..764d4d28d7920 100644 --- a/tests/providers/microsoft/azure/operators/test_azure_service_queue.py +++ b/tests/providers/microsoft/azure/operators/test_azure_service_queue.py @@ -221,13 +221,14 @@ def test_receive_message_queue(self, mock_get_conn): ) asb_receive_queue_operator.execute(None) expected_calls = [ - mock.call(), - mock.call().__enter__(), - mock.call().get_queue_receiver(queue_name='test_queue'), - mock.call().get_queue_receiver().__enter__(), - mock.call().get_queue_receiver().receive_messages(max_message_count=10, max_wait_time=5), - mock.call().get_queue_receiver().receive_messages().__iter__(), - mock.call().get_queue_receiver().__exit__(None, None, None), - mock.call().__exit__(None, None, None), + mock.call() + .__enter__() + .get_queue_receiver(QUEUE_NAME) + .__enter__() + .receive_messages(max_message_count=10, max_wait_time=5) + .get_queue_receiver(QUEUE_NAME) + .__exit__() + .mock_call() + .__exit__ ] mock_get_conn.assert_has_calls(expected_calls) From d6d156e8f23b1ed574c3048f4660af7dc2626f18 Mon Sep 17 00:00:00 2001 From: bharanidharan14 Date: Tue, 31 May 2022 14:07:58 +0530 Subject: [PATCH 05/13] Add Azure Service Bus Queue Operator's - Create Queue Operator - Send message queue Operator - Receive message queue Operator - Delete queue Operator - example DAG - Added hooks and connection type in - provider yaml file - Added unit Test case --- .../example_dags/example_service_bus_queue.py | 105 ++++++++ .../microsoft/azure/hooks/asb_admin_client.py | 90 +++++++ .../microsoft/azure/hooks/asb_message.py | 122 +++++++++ .../microsoft/azure/hooks/base_asb.py | 73 ++++++ .../operators/azure_service_bus_queue.py | 184 ++++++++++++++ .../providers/microsoft/azure/provider.yaml | 17 ++ .../connections/azure_service_bus.rst | 64 +++++ docs/integration-logos/azure/Service-Bus.svg | 1 + setup.py | 1 + .../azure/hooks/test_asb_admin_client.py | 104 ++++++++ .../microsoft/azure/hooks/test_asb_message.py | 150 ++++++++++++ .../operators/test_azure_service_queue.py | 231 ++++++++++++++++++ 12 files changed, 1142 insertions(+) create mode 100644 airflow/providers/microsoft/azure/example_dags/example_service_bus_queue.py create mode 100644 airflow/providers/microsoft/azure/hooks/asb_admin_client.py create mode 100644 airflow/providers/microsoft/azure/hooks/asb_message.py create mode 100644 airflow/providers/microsoft/azure/hooks/base_asb.py create mode 100644 airflow/providers/microsoft/azure/operators/azure_service_bus_queue.py create mode 100644 docs/apache-airflow-providers-microsoft-azure/connections/azure_service_bus.rst create mode 100644 docs/integration-logos/azure/Service-Bus.svg create mode 100644 tests/providers/microsoft/azure/hooks/test_asb_admin_client.py create mode 100644 tests/providers/microsoft/azure/hooks/test_asb_message.py create mode 100644 tests/providers/microsoft/azure/operators/test_azure_service_queue.py diff --git a/airflow/providers/microsoft/azure/example_dags/example_service_bus_queue.py b/airflow/providers/microsoft/azure/example_dags/example_service_bus_queue.py new file mode 100644 index 0000000000000..da689b82e1628 --- /dev/null +++ b/airflow/providers/microsoft/azure/example_dags/example_service_bus_queue.py @@ -0,0 +1,105 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import os +from datetime import datetime, timedelta + +from airflow import DAG +from airflow.providers.microsoft.azure.operators.azure_service_bus_queue import ( + AzureServiceBusCreateQueueOperator, + AzureServiceBusDeleteQueueOperator, + AzureServiceBusReceiveMessageOperator, + AzureServiceBusSendMessageOperator, +) + +EXECUTION_TIMEOUT = int(os.getenv("EXECUTION_TIMEOUT", 6)) + +default_args = { + "execution_timeout": timedelta(hours=EXECUTION_TIMEOUT), + "azure_service_bus_conn_id": "azure_service_bus_default", +} + +CLIENT_ID = os.getenv("CLIENT_ID", "") +QUEUE_NAME = "sb_mgmt_queue_test" +MESSAGE = "Test Message" +MESSAGE_LIST = [MESSAGE + " " + str(n) for n in range(0, 10)] + +with DAG( + dag_id="example_azure_service_bus_queue", + start_date=datetime(2021, 8, 13), + schedule_interval=None, + catchup=False, + default_args=default_args, + tags=["example", "Azure service bus"], +) as dag: + # [START howto_operator_create_service_bus_queue] + create_service_bus_queue = AzureServiceBusCreateQueueOperator( + task_id="create_service_bus_queue", + queue_name=QUEUE_NAME, + ) + # [END howto_operator_create_service_bus_queue] + + # [START howto_operator_send_message_to_service_bus_queue] + send_message_to_service_bus_queue = AzureServiceBusSendMessageOperator( + task_id="send_message_to_service_bus_queue", + message=MESSAGE, + queue_name=QUEUE_NAME, + batch=False, + ) + # [END howto_operator_send_message_to_service_bus_queue] + + # [START howto_operator_send_list_message_to_service_bus_queue] + send_list_message_to_service_bus_queue = AzureServiceBusSendMessageOperator( + task_id="send_list_message_to_service_bus_queue", + message=MESSAGE_LIST, + queue_name=QUEUE_NAME, + batch=False, + ) + # [END howto_operator_send_list_message_to_service_bus_queue] + + # [START howto_operator_send_batch_message_to_service_bus_queue] + send_batch_message_to_service_bus_queue = AzureServiceBusSendMessageOperator( + task_id="send_batch_message_to_service_bus_queue", + message=MESSAGE_LIST, + queue_name=QUEUE_NAME, + batch=True, + ) + # [END howto_operator_send_batch_message_to_service_bus_queue] + + # [START howto_operator_receive_message_service_bus_queue] + receive_message_service_bus_queue = AzureServiceBusReceiveMessageOperator( + task_id="receive_message_service_bus_queue", + queue_name=QUEUE_NAME, + max_message_count=20, + max_wait_time=5, + ) + # [END howto_operator_receive_message_service_bus_queue] + + # [START howto_operator_delete_service_bus_queue] + delete_service_bus_queue = AzureServiceBusDeleteQueueOperator( + task_id="delete_service_bus_queue", queue_name=QUEUE_NAME, trigger_rule="all_done" + ) + # [END howto_operator_delete_service_bus_queue] + + ( + create_service_bus_queue + >> send_message_to_service_bus_queue + >> send_list_message_to_service_bus_queue + >> send_batch_message_to_service_bus_queue + >> receive_message_service_bus_queue + >> delete_service_bus_queue + ) diff --git a/airflow/providers/microsoft/azure/hooks/asb_admin_client.py b/airflow/providers/microsoft/azure/hooks/asb_admin_client.py new file mode 100644 index 0000000000000..f56b596800996 --- /dev/null +++ b/airflow/providers/microsoft/azure/hooks/asb_admin_client.py @@ -0,0 +1,90 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from azure.servicebus.management import QueueProperties, ServiceBusAdministrationClient + +from airflow.exceptions import AirflowBadRequest, AirflowException +from airflow.providers.microsoft.azure.hooks.base_asb import BaseAzureServiceBusHook + + +class AzureServiceBusAdminClientHook(BaseAzureServiceBusHook): + """ + Interacts with Azure ServiceBus management client + and Use this client to create, update, list, and delete resources of a ServiceBus namespace. + it uses the same azure service bus client connection inherits from the base class + """ + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + def get_conn(self) -> ServiceBusAdministrationClient: + """Create and returns ServiceBusAdministration by using the connection string in connection details""" + conn = self.get_connection(self.conn_id) + extras = conn.extra_dejson + + self.connection_string = str( + extras.get('connection_string') or extras.get('extra__azure_service_bus__connection_string') + ) + return ServiceBusAdministrationClient.from_connection_string(self.connection_string) + + def create_queue( + self, + queue_name: str, + max_delivery_count: int = 10, + dead_lettering_on_message_expiration: bool = True, + enable_batched_operations: bool = True, + ) -> QueueProperties: + """ + Create Queue by connecting to service Bus Admin client return the QueueProperties + + :param queue_name: The name of the queue or a QueueProperties with name. + :param max_delivery_count: The maximum delivery count. A message is automatically + dead lettered after this number of deliveries. Default value is 10.. + :param dead_lettering_on_message_expiration: A value that indicates whether this subscription has + dead letter support when a message expires. + :param enable_batched_operations: Value that indicates whether server-side batched + operations are enabled. + """ + if queue_name is None: + raise AirflowBadRequest("Queue name cannot be None.") + + try: + with self.get_conn() as service_mgmt_conn: + queue = service_mgmt_conn.create_queue( + queue_name, + max_delivery_count=max_delivery_count, + dead_lettering_on_message_expiration=dead_lettering_on_message_expiration, + enable_batched_operations=enable_batched_operations, + ) + return queue + except Exception as e: + raise AirflowException(e) + + def delete_queue(self, queue_name: str) -> None: + """ + Delete the queue by queue_name in service bus namespace + + :param queue_name: The name of the queue or a QueueProperties with name. + """ + if queue_name is None: + raise AirflowBadRequest("Queue name cannot be None.") + + try: + with self.get_conn() as service_mgmt_conn: + service_mgmt_conn.delete_queue(queue_name) + except Exception as e: + raise AirflowException(e) diff --git a/airflow/providers/microsoft/azure/hooks/asb_message.py b/airflow/providers/microsoft/azure/hooks/asb_message.py new file mode 100644 index 0000000000000..b6f8e8931944c --- /dev/null +++ b/airflow/providers/microsoft/azure/hooks/asb_message.py @@ -0,0 +1,122 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import List, Optional, Union + +from azure.servicebus import ServiceBusClient, ServiceBusMessage, ServiceBusSender +from azure.servicebus.exceptions import MessageSizeExceededError + +from airflow.exceptions import AirflowBadRequest, AirflowException +from airflow.providers.microsoft.azure.hooks.base_asb import BaseAzureServiceBusHook + + +class ServiceBusMessageHook(BaseAzureServiceBusHook): + """ + Interacts with ServiceBusClient and acts as a high level interface + for getting ServiceBusSender and ServiceBusReceiver. + """ + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + def get_conn(self) -> ServiceBusClient: + """Create and returns ServiceBusClient by using the connection string in connection details""" + conn = self.get_connection(self.conn_id) + extras = conn.extra_dejson + + self.connection_string = str( + extras.get('connection_string') or extras.get('extra__azure_service_bus__connection_string') + ) + + return ServiceBusClient.from_connection_string(conn_str=self.connection_string, logging_enable=True) + + def send_message( + self, queue_name: str, messages: Union[str, List[str]], batch_message_flag: bool = False + ): + """ + By using ServiceBusClient Send message(s) to a Service Bus Queue. By using + batch_message_flag it enables and send message as batch message + + :param queue_name: The name of the queue or a QueueProperties with name. + :param messages: Message which needs to be sent to the queue. It can be string or list of string. + :param batch_message_flag: bool flag, can be set to True if message needs to be sent as batch message. + """ + if queue_name is None: + raise AirflowBadRequest("Queue name cannot be None.") + if not messages: + raise AirflowException("Message is empty.") + service_bus_client = self.get_conn() + try: + with service_bus_client: + sender = service_bus_client.get_queue_sender(queue_name=queue_name) + with sender: + if isinstance(messages, str): + if not batch_message_flag: + msg = ServiceBusMessage(messages) + sender.send_messages(msg) + else: + self.send_batch_message(sender, [messages]) + else: + if not batch_message_flag: + self.send_list_messages(sender, messages) + else: + self.send_batch_message(sender, messages) + except Exception as e: + raise AirflowException(e) + + @staticmethod + def send_list_messages(sender: ServiceBusSender, messages: List[str]): + list_messages = [ServiceBusMessage(message) for message in messages] + sender.send_messages(list_messages) # type: ignore[arg-type] + + @staticmethod + def send_batch_message(sender: ServiceBusSender, messages: List[str]): + batch_message = sender.create_message_batch() + for message in messages: + try: + batch_message.add_message(ServiceBusMessage(message)) + except MessageSizeExceededError as e: + # ServiceBusMessageBatch object reaches max_size. + # New ServiceBusMessageBatch object can be created here to send more data. + raise AirflowException(e) + sender.send_messages(batch_message) + + def receive_message( + self, queue_name, max_message_count: Optional[int] = 1, max_wait_time: Optional[float] = None + ): + """ + Receive a batch of messages at once in a specified Queue name + + :param queue_name: The name of the queue name or a QueueProperties with name. + :param max_message_count: Maximum number of messages in the batch. + :param max_wait_time: Maximum time to wait in seconds for the first message to arrive. + """ + if queue_name is None: + raise AirflowBadRequest("Queue name cannot be None.") + + service_bus_client = self.get_conn() + try: + with service_bus_client: + receiver = service_bus_client.get_queue_receiver(queue_name=queue_name) + with receiver: + received_msgs = receiver.receive_messages( + max_message_count=max_message_count, max_wait_time=max_wait_time + ) + for msg in received_msgs: + self.log.info(msg) + receiver.complete_message(msg) + except Exception as e: + raise AirflowException(e) diff --git a/airflow/providers/microsoft/azure/hooks/base_asb.py b/airflow/providers/microsoft/azure/hooks/base_asb.py new file mode 100644 index 0000000000000..84a0c20b6d629 --- /dev/null +++ b/airflow/providers/microsoft/azure/hooks/base_asb.py @@ -0,0 +1,73 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Any, Dict, Optional + +from airflow.hooks.base import BaseHook + + +class BaseAzureServiceBusHook(BaseHook): + """ + BaseAzureServiceBusHook class to session creation and connection creation. Client ID and + Secrete IDs are optional. + + :param azure_service_bus_conn_id: Reference to the + :ref:`Azure Service Bus connection`. + """ + + conn_name_attr = 'azure_service_bus_conn_id' + default_conn_name = 'azure_service_bus_default' + conn_type = 'azure_service_bus' + hook_name = 'Azure ServiceBus' + + @staticmethod + def get_connection_form_widgets() -> Dict[str, Any]: + """Returns connection widgets to add to connection form""" + from flask_appbuilder.fieldwidgets import BS3TextFieldWidget + from flask_babel import lazy_gettext + from wtforms import StringField + + return { + "extra__azure_service_bus__connection_string": StringField( + lazy_gettext('Service Bus Connection String'), widget=BS3TextFieldWidget() + ), + } + + @staticmethod + def get_ui_field_behaviour() -> Dict[str, Any]: + """Returns custom field behaviour""" + return { + "hidden_fields": ['schema', 'port', 'host', 'extra'], + "relabeling": { + 'login': 'Client ID', + 'password': 'Secret', + }, + "placeholders": { + 'login': 'Client ID (Optional)', + 'password': 'Client Secret (Optional)', + 'extra__azure_service_bus__connection_string': 'Service Bus Connection String', + }, + } + + def __init__(self, azure_service_bus_conn_id: str = default_conn_name) -> None: + super().__init__() + self.conn_id = azure_service_bus_conn_id + self._conn = None + self.connection_string: Optional[str] = None + + def get_conn(self): + return None diff --git a/airflow/providers/microsoft/azure/operators/azure_service_bus_queue.py b/airflow/providers/microsoft/azure/operators/azure_service_bus_queue.py new file mode 100644 index 0000000000000..31798600a0f54 --- /dev/null +++ b/airflow/providers/microsoft/azure/operators/azure_service_bus_queue.py @@ -0,0 +1,184 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import TYPE_CHECKING, List, Optional, Sequence, Union + +from airflow.models import BaseOperator +from airflow.providers.microsoft.azure.hooks.asb_admin_client import AzureServiceBusAdminClientHook +from airflow.providers.microsoft.azure.hooks.asb_message import ServiceBusMessageHook + +if TYPE_CHECKING: + from airflow.utils.context import Context + + +class AzureServiceBusCreateQueueOperator(BaseOperator): + """ + Creates a Azure ServiceBus queue under a ServiceBus Namespace by using ServiceBusAdministrationClient + + :param queue_name: The name of the queue. should be unique. + :param azure_service_bus_conn_id: Reference to the + :ref:`Azure Service Bus connection`. + """ + + template_fields: Sequence[str] = ("queue_name",) + ui_color = "#e4f0e8" + + def __init__( + self, + *, + queue_name: str, + max_delivery_count: int = 10, + dead_lettering_on_message_expiration: bool = True, + enable_batched_operations: bool = True, + azure_service_bus_conn_id: str = 'azure_service_bus_default', + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.queue_name = queue_name + self.max_delivery_count = max_delivery_count + self.dead_lettering_on_message_expiration = dead_lettering_on_message_expiration + self.enable_batched_operations = enable_batched_operations + self.azure_service_bus_conn_id = azure_service_bus_conn_id + + def execute(self, context: "Context") -> None: + """Creates Queue in Service Bus namespace, by connecting to Service Bus Admin client""" + # Create the hook testing + hook = AzureServiceBusAdminClientHook(azure_service_bus_conn_id=self.azure_service_bus_conn_id) + + # create queue with name + queue = hook.create_queue( + self.queue_name, + self.max_delivery_count, + self.dead_lettering_on_message_expiration, + self.enable_batched_operations, + ) + self.log.info("Created Queue %s", queue.name) + self.log.info(queue) + + +class AzureServiceBusSendMessageOperator(BaseOperator): + """ + Send Message or batch message to the service bus queue + + :param message: Message which needs to be sent to the queue. It can be string or list of string. + :param batch: Its boolean flag by default it is set to False, if the message needs to be sent + as batch message it can be set to True. + :param azure_service_bus_conn_id: Reference to the + :ref: `Azure Service Bus connection`. + """ + + template_fields: Sequence[str] = ("queue_name", "message") + ui_color = "#e4f0e8" + + def __init__( + self, + *, + queue_name: str, + message: Union[str, List[str]], + batch: bool = False, + azure_service_bus_conn_id: str = 'azure_service_bus_default', + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.queue_name = queue_name + self.batch = batch + self.message = message + self.azure_service_bus_conn_id = azure_service_bus_conn_id + + def execute(self, context: "Context") -> None: + """ + Sends Message to the specific queue in Service Bus namespace, by + connecting to Service Bus client + """ + # Create the hook + hook = ServiceBusMessageHook(azure_service_bus_conn_id=self.azure_service_bus_conn_id) + + # send message + hook.send_message(self.queue_name, self.message, self.batch) + + +class AzureServiceBusReceiveMessageOperator(BaseOperator): + """ + Receive a batch of messages at once in a specified Queue name + + :param queue_name: The name of the queue in Service Bus namespace. + :param azure_service_bus_conn_id: Reference to the + :ref: `Azure Service Bus connection `. + """ + + template_fields: Sequence[str] = ("queue_name",) + ui_color = "#e4f0e8" + + def __init__( + self, + *, + queue_name: str, + azure_service_bus_conn_id: str = 'azure_service_bus_default', + max_message_count: Optional[int] = 10, + max_wait_time: Optional[float] = 5, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.queue_name = queue_name + self.azure_service_bus_conn_id = azure_service_bus_conn_id + self.max_message_count = max_message_count + self.max_wait_time = max_wait_time + + def execute(self, context: "Context") -> None: + """ + Receive Message in specific queue in Service Bus namespace, + by connecting to Service Bus client + """ + # Create the hook + hook = ServiceBusMessageHook(azure_service_bus_conn_id=self.azure_service_bus_conn_id) + + # Receive message + hook.receive_message( + self.queue_name, max_message_count=self.max_message_count, max_wait_time=self.max_wait_time + ) + + +class AzureServiceBusDeleteQueueOperator(BaseOperator): + """ + Deletes the Queue in the Azure ServiceBus namespace + + :param queue_name: The name of the queue in Service Bus namespace. + :param azure_service_bus_conn_id: Reference to the + :ref: `Azure Service Bus connection `. + """ + + template_fields: Sequence[str] = ("queue_name",) + ui_color = "#e4f0e8" + + def __init__( + self, + *, + queue_name: str, + azure_service_bus_conn_id: str = 'azure_service_bus_default', + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.queue_name = queue_name + self.azure_service_bus_conn_id = azure_service_bus_conn_id + + def execute(self, context: "Context") -> None: + """Delete Queue in Service Bus namespace, by connecting to Service Bus Admin client""" + # Create the hook + hook = AzureServiceBusAdminClientHook(azure_service_bus_conn_id=self.azure_service_bus_conn_id) + + # delete queue with name + hook.delete_queue(self.queue_name) diff --git a/airflow/providers/microsoft/azure/provider.yaml b/airflow/providers/microsoft/azure/provider.yaml index 10cab462b4c53..e4c024897ceea 100644 --- a/airflow/providers/microsoft/azure/provider.yaml +++ b/airflow/providers/microsoft/azure/provider.yaml @@ -88,6 +88,10 @@ integrations: external-doc-url: https://azure.microsoft.com/ logo: /integration-logos/azure/Microsoft-Azure.png tags: [azure] + - integration-name: Microsoft Azure Service Bus + external-doc-url: https://azure.microsoft.com/en-us/services/service-bus/ + logo: /integration-logos/azure/Service-Bus.svg + tags: [azure] operators: - integration-name: Microsoft Azure Data Lake Storage @@ -116,6 +120,9 @@ operators: - integration-name: Microsoft Azure Data Factory python-modules: - airflow.providers.microsoft.azure.operators.data_factory + - integration-name: Microsoft Azure Service Bus + python-modules: + - airflow.providers.microsoft.azure.operators.azure_service_bus_queue sensors: - integration-name: Microsoft Azure Cosmos DB @@ -167,6 +174,11 @@ hooks: python-modules: - airflow.providers.microsoft.azure.hooks.data_factory - airflow.providers.microsoft.azure.hooks.azure_data_factory + - integration-name: Microsoft Azure Service Bus + python-modules: + - airflow.providers.microsoft.azure.hooks.base_asb + - airflow.providers.microsoft.azure.hooks.asb_admin_client + - airflow.providers.microsoft.azure.hooks.asb_message transfers: - source-integration-name: Local @@ -203,6 +215,9 @@ hook-class-names: # deprecated - to be removed after providers add dependency o - airflow.providers.microsoft.azure.hooks.wasb.WasbHook - airflow.providers.microsoft.azure.hooks.data_factory.AzureDataFactoryHook - airflow.providers.microsoft.azure.hooks.container_registry.AzureContainerRegistryHook + - airflow.providers.microsoft.azure.hooks.base_asb.BaseAzureServiceBusHook + - airflow.providers.microsoft.azure.hooks.asb_admin_client.AzureServiceBusAdminClientHook + - airflow.providers.microsoft.azure.hooks.asb_message.ServiceBusMessageHook connection-types: - hook-class-name: airflow.providers.microsoft.azure.hooks.base_azure.AzureBaseHook @@ -229,6 +244,8 @@ connection-types: - hook-class-name: >- airflow.providers.microsoft.azure.hooks.container_registry.AzureContainerRegistryHook connection-type: azure_container_registry + - hook-class-name: airflow.providers.microsoft.azure.hooks.base_asb.BaseAzureServiceBusHook + connection-type: azure_service_bus secrets-backends: - airflow.providers.microsoft.azure.secrets.key_vault.AzureKeyVaultBackend diff --git a/docs/apache-airflow-providers-microsoft-azure/connections/azure_service_bus.rst b/docs/apache-airflow-providers-microsoft-azure/connections/azure_service_bus.rst new file mode 100644 index 0000000000000..638c993656d57 --- /dev/null +++ b/docs/apache-airflow-providers-microsoft-azure/connections/azure_service_bus.rst @@ -0,0 +1,64 @@ +.. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + + +.. _howto/connection:azure_service_bus: + +Microsoft Azure Service Bus +======================================= + +The Microsoft Azure Service Bus connection type enables the Azure Service Bus Integrations. + +Authenticating to Azure Service Bus +------------------------------------ + +There are two ways to authenticate and authorize access to Azure Service Bus resources: +Azure Active Directory (Azure AD) and Shared Access Signatures (SAS). + +1. Use `Azure Active Directory + `_ + i.e. add specific credentials (client_id, secret, tenant) and subscription id to the Airflow connection. +2. Use `Azure Shared access signature + `_ +3. Fallback on `DefaultAzureCredential + `_. + This includes a mechanism to try different options to authenticate: Managed System Identity, environment variables, authentication through Azure CLI... + +Default Connection IDs +---------------------- + +All hooks and operators related to Microsoft Azure Service Bus use ``azure_service_bus_default`` by default. + +Configuring the Connection +-------------------------- + +Client ID (optional) + Specify the ``client_id`` used for the initial connection. + This is needed for *token credentials* authentication mechanism. + It can be left out to fall back on ``DefaultAzureCredential``. + +Secret (optional) + Specify the ``secret`` used for the initial connection. + This is needed for *token credentials* authentication mechanism. + It can be left out to fall back on ``DefaultAzureCredential``. + +Connection ID + Specify the Azure Service bus connection string ID used for the initial connection. + `Get connection string + `_ + Use the key ``extra__azure_service_bus__connection_string`` to pass in the Connection ID . diff --git a/docs/integration-logos/azure/Service-Bus.svg b/docs/integration-logos/azure/Service-Bus.svg new file mode 100644 index 0000000000000..1604e04232630 --- /dev/null +++ b/docs/integration-logos/azure/Service-Bus.svg @@ -0,0 +1 @@ + diff --git a/setup.py b/setup.py index e73a88d6e7d8e..022f716d17153 100644 --- a/setup.py +++ b/setup.py @@ -231,6 +231,7 @@ def write_version(filename: str = os.path.join(*[my_dir, "airflow", "git_version 'azure-storage-blob>=12.7.0,<12.9.0', 'azure-storage-common>=2.1.0', 'azure-storage-file>=2.1.0', + 'azure-servicebus>=7.6.1', ] cassandra = [ 'cassandra-driver>=3.13.0', diff --git a/tests/providers/microsoft/azure/hooks/test_asb_admin_client.py b/tests/providers/microsoft/azure/hooks/test_asb_admin_client.py new file mode 100644 index 0000000000000..6521ed64e7699 --- /dev/null +++ b/tests/providers/microsoft/azure/hooks/test_asb_admin_client.py @@ -0,0 +1,104 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import json +from unittest import mock + +import pytest +from azure.servicebus.management import ServiceBusAdministrationClient + +from airflow import AirflowException +from airflow.models import Connection +from airflow.providers.microsoft.azure.hooks.asb_admin_client import AzureServiceBusAdminClientHook + + +class TestAzureServiceBusAdminClientHook: + def setup_class(self) -> None: + self.queue_name: str = "test_queue" + self.conn_id: str = 'azure_service_bus_default' + self.connection_string = ( + "Endpoint=sb://test-service-bus-provider.servicebus.windows.net/;" + "SharedAccessKeyName=Test;SharedAccessKey=1234566acbc" + ) + self.client_id = "test_client_id" + self.secret_key = "test_client_secret" + self.mock_conn = Connection( + conn_id='azure_service_bus_default', + conn_type='azure_service_bus', + login=self.client_id, + password=self.secret_key, + extra=json.dumps({"connection_string": self.connection_string}), + ) + + @mock.patch( + "airflow.providers.microsoft.azure.hooks.asb_admin_client." + "AzureServiceBusAdminClientHook.get_connection" + ) + def test_get_conn(self, mock_connection): + mock_connection.return_value = self.mock_conn + hook = AzureServiceBusAdminClientHook(azure_service_bus_conn_id=self.conn_id) + assert isinstance(hook.get_conn(), ServiceBusAdministrationClient) + + @mock.patch('azure.servicebus.management.QueueProperties') + @mock.patch( + 'airflow.providers.microsoft.azure.hooks.asb_admin_client.AzureServiceBusAdminClientHook.get_conn' + ) + def test_create_queue(self, mock_sb_admin_client, mock_queue_properties): + """ + Test `create_queue` hook function with mocking connection, queue properties value and + the azure service bus `create_queue` function + """ + mock_queue_properties.name = self.queue_name + mock_sb_admin_client.return_value.__enter__.return_value.create_queue.return_value = ( + mock_queue_properties + ) + hook = AzureServiceBusAdminClientHook(azure_service_bus_conn_id=self.conn_id) + response = hook.create_queue(self.queue_name) + assert response == mock_queue_properties + + @mock.patch('airflow.providers.microsoft.azure.hooks.asb_admin_client.ServiceBusAdministrationClient') + def test_create_queue_exception(self, mock_sb_admin_client): + """ + Test `create_queue` functionality to raise AirflowException + by passing queue name as None and pytest raise Airflow Exception + """ + hook = AzureServiceBusAdminClientHook(azure_service_bus_conn_id=self.conn_id) + with pytest.raises(AirflowException): + hook.create_queue(None) + + @mock.patch( + 'airflow.providers.microsoft.azure.hooks.asb_admin_client.AzureServiceBusAdminClientHook.get_conn' + ) + def test_delete_queue(self, mock_sb_admin_client): + """ + Test Delete queue functionality by passing queue name, assert the function with values, + mock the azure service bus function `delete_queue` + """ + hook = AzureServiceBusAdminClientHook(azure_service_bus_conn_id=self.conn_id) + hook.delete_queue(self.queue_name) + expected_calls = [mock.call().__enter__().delete_queue(self.queue_name)] + mock_sb_admin_client.assert_has_calls(expected_calls) + + @mock.patch('airflow.providers.microsoft.azure.hooks.asb_admin_client.ServiceBusAdministrationClient') + def test_delete_queue_exception(self, mock_sb_admin_client): + """ + Test `delete_queue` functionality to raise AirflowException, + by passing queue name as None and pytest raise Airflow Exception + """ + hook = AzureServiceBusAdminClientHook(azure_service_bus_conn_id=self.conn_id) + with pytest.raises(AirflowException): + hook.delete_queue(None) diff --git a/tests/providers/microsoft/azure/hooks/test_asb_message.py b/tests/providers/microsoft/azure/hooks/test_asb_message.py new file mode 100644 index 0000000000000..24b86361265a8 --- /dev/null +++ b/tests/providers/microsoft/azure/hooks/test_asb_message.py @@ -0,0 +1,150 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import json +from unittest import mock + +import pytest +from azure.servicebus import ServiceBusClient, ServiceBusMessage, ServiceBusMessageBatch + +from airflow import AirflowException +from airflow.models import Connection +from airflow.providers.microsoft.azure.hooks.asb_message import ServiceBusMessageHook + + +class TestServiceBusMessageHook: + def setup_class(self) -> None: + self.queue_name: str = "test_queue" + self.conn_id: str = 'azure_service_bus_default' + self.connection_string = ( + "Endpoint=sb://test-service-bus-provider.servicebus.windows.net/;" + "SharedAccessKeyName=Test;SharedAccessKey=1234566acbc" + ) + self.client_id = "test_client_id" + self.secret_key = "test_client_secret" + self.conn = Connection( + conn_id='azure_service_bus_default', + conn_type='azure_service_bus', + login=self.client_id, + password=self.secret_key, + extra=json.dumps({'connection_string': self.connection_string}), + ) + + @mock.patch("airflow.providers.microsoft.azure.hooks.asb_message.ServiceBusMessageHook.get_connection") + def test_get_service_bus_message_conn(self, mock_connection): + mock_connection.return_value = self.conn + hook = ServiceBusMessageHook(azure_service_bus_conn_id=self.conn_id) + assert isinstance(hook.get_conn(), ServiceBusClient) + + @mock.patch("airflow.providers.microsoft.azure.hooks.asb_message.ServiceBusMessageHook.get_connection") + def test_get_conn_value_error(self, mock_connection): + mock_connection.return_value = Connection( + conn_id='azure_service_bus_default', + conn_type='azure_service_bus', + login=self.client_id, + password=self.secret_key, + extra=json.dumps({"connection_string": "test connection"}), + ) + hook = ServiceBusMessageHook(azure_service_bus_conn_id=self.conn_id) + with pytest.raises(ValueError): + hook.get_conn() + + @pytest.mark.parametrize( + "mock_message, mock_batch_flag", + [ + ("Test message", True), + ("Test message", False), + (["Test message 1", "Test message 2"], True), + (["Test message 1", "Test message 2"], False), + ], + ) + @mock.patch( + 'airflow.providers.microsoft.azure.hooks.asb_message.ServiceBusMessageHook.send_list_messages' + ) + @mock.patch( + 'airflow.providers.microsoft.azure.hooks.asb_message.ServiceBusMessageHook.send_batch_message' + ) + @mock.patch('airflow.providers.microsoft.azure.hooks.asb_message.ServiceBusMessageHook.get_conn') + def test_send_message( + self, mock_sb_client, mock_batch_message, mock_list_message, mock_message, mock_batch_flag + ): + """ + Test `send_message` hook function with batch flag and message passed as mocked params, + which can be string or list of string, mock the azure service bus `send_messages` function + """ + hook = ServiceBusMessageHook(azure_service_bus_conn_id="azure_service_bus_default") + hook.send_message( + queue_name=self.queue_name, messages=mock_message, batch_message_flag=mock_batch_flag + ) + if isinstance(mock_message, list): + if mock_batch_flag: + message = ServiceBusMessageBatch(mock_message) + else: + message = [ServiceBusMessage(msg) for msg in mock_message] + elif isinstance(mock_message, str): + if mock_batch_flag: + message = ServiceBusMessageBatch(mock_message) + else: + message = ServiceBusMessage(mock_message) + + expected_calls = [ + mock.call() + .__enter__() + .get_queue_sender(self.queue_name) + .__enter__() + .send_messages(message) + .__exit__() + ] + mock_sb_client.assert_has_calls(expected_calls, any_order=False) + + @mock.patch('airflow.providers.microsoft.azure.hooks.asb_message.ServiceBusMessageHook.get_conn') + def test_send_message_exception(self, mock_sb_client): + """ + Test `send_message` functionality to raise AirflowException in Azure ServiceBusMessageHook + by passing queue name as None + """ + hook = ServiceBusMessageHook(azure_service_bus_conn_id=self.conn_id) + with pytest.raises(AirflowException): + hook.send_message(queue_name=None, messages="", batch_message_flag=False) + + @mock.patch('airflow.providers.microsoft.azure.hooks.asb_message.ServiceBusMessageHook.get_conn') + def test_receive_message(self, mock_sb_client): + """ + Test `receive_message` hook function and assert the function with mock value, + mock the azure service bus `receive_messages` function + """ + hook = ServiceBusMessageHook(azure_service_bus_conn_id=self.conn_id) + hook.receive_message(self.queue_name) + expected_calls = [ + mock.call() + .__enter__() + .get_queue_receiver(self.queue_name) + .__enter__() + .receive_messages(max_message_count=10, max_wait_time=5) + .__iter__() + ] + mock_sb_client.assert_has_calls(expected_calls) + + @mock.patch('airflow.providers.microsoft.azure.hooks.asb_message.ServiceBusMessageHook.get_conn') + def test_receive_message_exception(self, mock_sb_client): + """ + Test `receive_message` functionality to raise AirflowException in Azure ServiceBusMessageHook + by passing queue name as None + """ + hook = ServiceBusMessageHook(azure_service_bus_conn_id=self.conn_id) + with pytest.raises(AirflowException): + hook.receive_message(None) diff --git a/tests/providers/microsoft/azure/operators/test_azure_service_queue.py b/tests/providers/microsoft/azure/operators/test_azure_service_queue.py new file mode 100644 index 0000000000000..4294fd42407d0 --- /dev/null +++ b/tests/providers/microsoft/azure/operators/test_azure_service_queue.py @@ -0,0 +1,231 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import datetime +import unittest +from unittest import mock + +from azure.servicebus import ServiceBusMessage + +from airflow.models.dag import DAG +from airflow.providers.microsoft.azure.operators.azure_service_bus_queue import ( + AzureServiceBusCreateQueueOperator, + AzureServiceBusDeleteQueueOperator, + AzureServiceBusReceiveMessageOperator, + AzureServiceBusSendMessageOperator, +) + +QUEUE_NAME = "test_queue" +OWNER_NAME = "airflow" +DAG_ID = "test_azure_service_bus_queue" + + +class TestAzureServiceBusCreateQueueOperator(unittest.TestCase): + def setUp(self): + args = {'owner': OWNER_NAME, 'start_date': datetime.datetime(2017, 1, 1)} + self.dag = DAG(DAG_ID, default_args=args) + + def test_init(self): + """ + Test init by creating AzureServiceBusCreateQueueOperator with task id, queue_name and + asserting with value + """ + asb_create_queue_operator = AzureServiceBusCreateQueueOperator( + task_id="asb_create_queue", + queue_name=QUEUE_NAME, + dag=self.dag, + max_delivery_count=10, + dead_lettering_on_message_expiration=True, + enable_batched_operations=True, + ) + assert asb_create_queue_operator.task_id == "asb_create_queue" + assert asb_create_queue_operator.queue_name == QUEUE_NAME + assert asb_create_queue_operator.max_delivery_count == 10 + assert asb_create_queue_operator.dead_lettering_on_message_expiration is True + assert asb_create_queue_operator.enable_batched_operations is True + + asb_create_queue_operator = AzureServiceBusCreateQueueOperator( + task_id="asb_create_queue_test", + queue_name=QUEUE_NAME, + dag=self.dag, + max_delivery_count=10, + dead_lettering_on_message_expiration=False, + enable_batched_operations=False, + ) + assert asb_create_queue_operator.task_id == "asb_create_queue_test" + assert asb_create_queue_operator.queue_name == QUEUE_NAME + assert asb_create_queue_operator.max_delivery_count == 10 + assert asb_create_queue_operator.dead_lettering_on_message_expiration is False + assert asb_create_queue_operator.enable_batched_operations is False + + @mock.patch( + "airflow.providers.microsoft.azure.hooks.asb_admin_client.AzureServiceBusAdminClientHook.get_conn" + ) + def test_create_queue(self, mock_get_conn): + """ + Test AzureServiceBusCreateQueueOperator passed with the queue name, + mocking the connection details, hook create_queue function + """ + asb_create_queue_operator = AzureServiceBusCreateQueueOperator( + task_id="asb_create_queue_operator", + queue_name=QUEUE_NAME, + max_delivery_count=10, + dead_lettering_on_message_expiration=True, + enable_batched_operations=True, + dag=self.dag, + ) + asb_create_queue_operator.execute(None) + mock_get_conn.return_value.__enter__.return_value.create_queue.assert_called_once_with( + QUEUE_NAME, + max_delivery_count=10, + dead_lettering_on_message_expiration=True, + enable_batched_operations=True, + ) + + +class TestAzureServiceBusDeleteQueueOperator(unittest.TestCase): + def setUp(self): + args = {'owner': OWNER_NAME, 'start_date': datetime.datetime(2017, 1, 1)} + self.dag = DAG(DAG_ID, default_args=args) + + def test_init(self): + """ + Test init by creating AzureServiceBusDeleteQueueOperator with task id, queue_name and + asserting with values + """ + asb_delete_queue_operator = AzureServiceBusDeleteQueueOperator( + task_id="asb_delete_queue", + queue_name=QUEUE_NAME, + dag=self.dag, + ) + assert asb_delete_queue_operator.task_id == "asb_delete_queue" + assert asb_delete_queue_operator.queue_name == QUEUE_NAME + + @mock.patch( + "airflow.providers.microsoft.azure.hooks.asb_admin_client.AzureServiceBusAdminClientHook.get_conn" + ) + def test_delete_queue(self, mock_get_conn): + """Test AzureServiceBusDeleteQueueOperator by mocking queue name, connection and hook delete_queue""" + asb_delete_queue_operator = AzureServiceBusDeleteQueueOperator( + task_id="asb_delete_queue", + queue_name=QUEUE_NAME, + dag=self.dag, + ) + asb_delete_queue_operator.execute(None) + mock_get_conn.return_value.__enter__.return_value.delete_queue.assert_called_once_with(QUEUE_NAME) + + +class TestAzureServiceBusSendMessageOperator(unittest.TestCase): + def setUp(self): + args = {'owner': OWNER_NAME, 'start_date': datetime.datetime(2017, 1, 1)} + self.dag = DAG(DAG_ID, default_args=args) + + def test_init(self): + """ + Test init by creating AzureServiceBusSendMessageOperator with task id, queue_name, message, + batch and asserting with values + """ + msg = "test message" + asb_send_message_queue_operator = AzureServiceBusSendMessageOperator( + task_id="asb_send_message_queue_without_batch", + queue_name=QUEUE_NAME, + message=msg, + batch=False, + dag=self.dag, + ) + assert asb_send_message_queue_operator.task_id == "asb_send_message_queue_without_batch" + assert asb_send_message_queue_operator.queue_name == QUEUE_NAME + assert asb_send_message_queue_operator.message == msg + assert asb_send_message_queue_operator.batch is False + + asb_send_message_queue_operator = AzureServiceBusSendMessageOperator( + task_id="asb_send_message_queue_with_batch", + queue_name=QUEUE_NAME, + message=msg, + batch=True, + dag=self.dag, + ) + assert asb_send_message_queue_operator.task_id == "asb_send_message_queue_with_batch" + assert asb_send_message_queue_operator.queue_name == QUEUE_NAME + assert asb_send_message_queue_operator.message == msg + assert asb_send_message_queue_operator.batch is True + + @mock.patch("airflow.providers.microsoft.azure.hooks.asb_message.ServiceBusMessageHook.get_conn") + def test_send_message_queue(self, mock_get_conn): + """ + Test AzureServiceBusSendMessageOperator with queue name, batch boolean flag, mock + the send_messages of azure service bus function + """ + asb_send_message_queue_operator = AzureServiceBusSendMessageOperator( + task_id="asb_send_message_queue", + queue_name=QUEUE_NAME, + message="Test message", + batch=False, + dag=self.dag, + ) + asb_send_message_queue_operator.execute(None) + expected_calls = [ + mock.call() + .__enter__() + .get_queue_sender(QUEUE_NAME) + .__enter__() + .send_messages(ServiceBusMessage("Test message")) + .__exit__() + ] + mock_get_conn.assert_has_calls(expected_calls, any_order=False) + + +class TestAzureServiceBusReceiveMessageOperator(unittest.TestCase): + def setUp(self): + args = {'owner': OWNER_NAME, 'start_date': datetime.datetime(2017, 1, 1)} + self.dag = DAG(DAG_ID, default_args=args) + + def test_init(self): + """ + Test init by creating AzureServiceBusReceiveMessageOperator with task id, queue_name, message, + batch and asserting with values + """ + + asb_receive_queue_operator = AzureServiceBusReceiveMessageOperator( + task_id="asb_receive_message_queue", + queue_name=QUEUE_NAME, + dag=self.dag, + ) + assert asb_receive_queue_operator.task_id == "asb_receive_message_queue" + assert asb_receive_queue_operator.queue_name == QUEUE_NAME + + @mock.patch("airflow.providers.microsoft.azure.hooks.asb_message.ServiceBusMessageHook.get_conn") + def test_receive_message_queue(self, mock_get_conn): + """ + Test AzureServiceBusReceiveMessageOperator by mock connection, values + and the service bus receive message + """ + asb_receive_queue_operator = AzureServiceBusReceiveMessageOperator( + task_id="asb_receive_message_queue", + queue_name=QUEUE_NAME, + dag=self.dag, + ) + asb_receive_queue_operator.execute(None) + expected_calls = [ + mock.call() + .__enter__() + .get_queue_receiver(QUEUE_NAME) + .__enter__() + .receive_messages(max_message_count=10, max_wait_time=5) + .__iter__() + ] + mock_get_conn.assert_has_calls(expected_calls) From 1b961566b6b0eb1d55ce558d43aec8126989099c Mon Sep 17 00:00:00 2001 From: bharanidharan14 Date: Tue, 31 May 2022 21:27:56 +0530 Subject: [PATCH 06/13] Fix CircleCI failure --- airflow/providers/microsoft/azure/provider.yaml | 2 -- 1 file changed, 2 deletions(-) diff --git a/airflow/providers/microsoft/azure/provider.yaml b/airflow/providers/microsoft/azure/provider.yaml index e4c024897ceea..805fed9ab973c 100644 --- a/airflow/providers/microsoft/azure/provider.yaml +++ b/airflow/providers/microsoft/azure/provider.yaml @@ -216,8 +216,6 @@ hook-class-names: # deprecated - to be removed after providers add dependency o - airflow.providers.microsoft.azure.hooks.data_factory.AzureDataFactoryHook - airflow.providers.microsoft.azure.hooks.container_registry.AzureContainerRegistryHook - airflow.providers.microsoft.azure.hooks.base_asb.BaseAzureServiceBusHook - - airflow.providers.microsoft.azure.hooks.asb_admin_client.AzureServiceBusAdminClientHook - - airflow.providers.microsoft.azure.hooks.asb_message.ServiceBusMessageHook connection-types: - hook-class-name: airflow.providers.microsoft.azure.hooks.base_azure.AzureBaseHook From 68f1ec5954f044385453612653c37da078124f67 Mon Sep 17 00:00:00 2001 From: Pankaj Date: Tue, 31 May 2022 21:07:34 +0530 Subject: [PATCH 07/13] Fix tests --- .../microsoft/azure/hooks/test_asb_message.py | 14 ++++++++------ .../azure/operators/test_azure_service_queue.py | 14 ++++++++------ 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/tests/providers/microsoft/azure/hooks/test_asb_message.py b/tests/providers/microsoft/azure/hooks/test_asb_message.py index 24b86361265a8..bbe58056f9817 100644 --- a/tests/providers/microsoft/azure/hooks/test_asb_message.py +++ b/tests/providers/microsoft/azure/hooks/test_asb_message.py @@ -130,12 +130,14 @@ def test_receive_message(self, mock_sb_client): hook = ServiceBusMessageHook(azure_service_bus_conn_id=self.conn_id) hook.receive_message(self.queue_name) expected_calls = [ - mock.call() - .__enter__() - .get_queue_receiver(self.queue_name) - .__enter__() - .receive_messages(max_message_count=10, max_wait_time=5) - .__iter__() + mock.call(), + mock.call().__enter__(), + mock.call().get_queue_receiver(queue_name='test_queue'), + mock.call().get_queue_receiver().__enter__(), + mock.call().get_queue_receiver().receive_messages(max_message_count=1, max_wait_time=None), + mock.call().get_queue_receiver().receive_messages().__iter__(), + mock.call().get_queue_receiver().__exit__(None, None, None), + mock.call().__exit__(None, None, None), ] mock_sb_client.assert_has_calls(expected_calls) diff --git a/tests/providers/microsoft/azure/operators/test_azure_service_queue.py b/tests/providers/microsoft/azure/operators/test_azure_service_queue.py index 4294fd42407d0..d307234462732 100644 --- a/tests/providers/microsoft/azure/operators/test_azure_service_queue.py +++ b/tests/providers/microsoft/azure/operators/test_azure_service_queue.py @@ -221,11 +221,13 @@ def test_receive_message_queue(self, mock_get_conn): ) asb_receive_queue_operator.execute(None) expected_calls = [ - mock.call() - .__enter__() - .get_queue_receiver(QUEUE_NAME) - .__enter__() - .receive_messages(max_message_count=10, max_wait_time=5) - .__iter__() + mock.call(), + mock.call().__enter__(), + mock.call().get_queue_receiver(queue_name='test_queue'), + mock.call().get_queue_receiver().__enter__(), + mock.call().get_queue_receiver().receive_messages(max_message_count=10, max_wait_time=5), + mock.call().get_queue_receiver().receive_messages().__iter__(), + mock.call().get_queue_receiver().__exit__(None, None, None), + mock.call().__exit__(None, None, None), ] mock_get_conn.assert_has_calls(expected_calls) From b20d7828481c56edcaf75d7fe8c39a33ea30ffb9 Mon Sep 17 00:00:00 2001 From: bharanidharan14 Date: Tue, 31 May 2022 23:18:16 +0530 Subject: [PATCH 08/13] Fix test case --- .../microsoft/azure/hooks/test_asb_message.py | 23 +++++++++++-------- .../operators/test_azure_service_queue.py | 17 +++++++------- 2 files changed, 23 insertions(+), 17 deletions(-) diff --git a/tests/providers/microsoft/azure/hooks/test_asb_message.py b/tests/providers/microsoft/azure/hooks/test_asb_message.py index bbe58056f9817..bc1d93c366009 100644 --- a/tests/providers/microsoft/azure/hooks/test_asb_message.py +++ b/tests/providers/microsoft/azure/hooks/test_asb_message.py @@ -121,23 +121,28 @@ def test_send_message_exception(self, mock_sb_client): with pytest.raises(AirflowException): hook.send_message(queue_name=None, messages="", batch_message_flag=False) + @mock.patch('azure.servicebus.ServiceBusMessage') @mock.patch('airflow.providers.microsoft.azure.hooks.asb_message.ServiceBusMessageHook.get_conn') - def test_receive_message(self, mock_sb_client): + def test_receive_message(self, mock_sb_client, mock_service_bus_message): """ Test `receive_message` hook function and assert the function with mock value, mock the azure service bus `receive_messages` function """ hook = ServiceBusMessageHook(azure_service_bus_conn_id=self.conn_id) + mock_sb_client.return_value.get_queue_receiver.return_value.receive_messages.return_value = [ + mock_service_bus_message + ] hook.receive_message(self.queue_name) expected_calls = [ - mock.call(), - mock.call().__enter__(), - mock.call().get_queue_receiver(queue_name='test_queue'), - mock.call().get_queue_receiver().__enter__(), - mock.call().get_queue_receiver().receive_messages(max_message_count=1, max_wait_time=None), - mock.call().get_queue_receiver().receive_messages().__iter__(), - mock.call().get_queue_receiver().__exit__(None, None, None), - mock.call().__exit__(None, None, None), + mock.call() + .__enter__() + .get_queue_receiver(self.queue_name) + .__enter__() + .receive_messages(max_message_count=30, max_wait_time=5) + .get_queue_receiver(self.queue_name) + .__exit__() + .mock_call() + .__exit__ ] mock_sb_client.assert_has_calls(expected_calls) diff --git a/tests/providers/microsoft/azure/operators/test_azure_service_queue.py b/tests/providers/microsoft/azure/operators/test_azure_service_queue.py index d307234462732..764d4d28d7920 100644 --- a/tests/providers/microsoft/azure/operators/test_azure_service_queue.py +++ b/tests/providers/microsoft/azure/operators/test_azure_service_queue.py @@ -221,13 +221,14 @@ def test_receive_message_queue(self, mock_get_conn): ) asb_receive_queue_operator.execute(None) expected_calls = [ - mock.call(), - mock.call().__enter__(), - mock.call().get_queue_receiver(queue_name='test_queue'), - mock.call().get_queue_receiver().__enter__(), - mock.call().get_queue_receiver().receive_messages(max_message_count=10, max_wait_time=5), - mock.call().get_queue_receiver().receive_messages().__iter__(), - mock.call().get_queue_receiver().__exit__(None, None, None), - mock.call().__exit__(None, None, None), + mock.call() + .__enter__() + .get_queue_receiver(QUEUE_NAME) + .__enter__() + .receive_messages(max_message_count=10, max_wait_time=5) + .get_queue_receiver(QUEUE_NAME) + .__exit__() + .mock_call() + .__exit__ ] mock_get_conn.assert_has_calls(expected_calls) From 9b20cd5fe15dbb14428fa96cf0ea31c2470713a6 Mon Sep 17 00:00:00 2001 From: bharanidharan14 Date: Wed, 1 Jun 2022 08:53:41 +0530 Subject: [PATCH 09/13] Implemented Azure service bus subscription Operators - Added ASBCreateSubscriptionOperator - Added ASBUpdateSubscriptionOperator - Added ASBReceiveSubscriptionMessageOperator - Added ASBDeleteSubscriptionOperator - example DAG - Added hooks - Added unit Test case --- .../example_service_bus_subscription.py | 88 +++++++ .../microsoft/azure/hooks/asb_admin_client.py | 126 +++++++++- .../microsoft/azure/hooks/asb_message.py | 40 ++++ .../azure/hooks/test_asb_admin_client.py | 134 +++++++++++ .../microsoft/azure/hooks/test_asb_message.py | 43 ++++ .../test_azure_service_subscription.py | 223 ++++++++++++++++++ 6 files changed, 653 insertions(+), 1 deletion(-) create mode 100644 airflow/providers/microsoft/azure/example_dags/example_service_bus_subscription.py create mode 100644 tests/providers/microsoft/azure/operators/test_azure_service_subscription.py diff --git a/airflow/providers/microsoft/azure/example_dags/example_service_bus_subscription.py b/airflow/providers/microsoft/azure/example_dags/example_service_bus_subscription.py new file mode 100644 index 0000000000000..d7d7f46756ef9 --- /dev/null +++ b/airflow/providers/microsoft/azure/example_dags/example_service_bus_subscription.py @@ -0,0 +1,88 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import os +from datetime import datetime, timedelta + +from airflow import DAG +from airflow.providers.microsoft.azure.operators.azure_service_bus_subscription import ( + ASBCreateSubscriptionOperator, + ASBDeleteSubscriptionOperator, + ASBReceiveSubscriptionMessageOperator, + ASBUpdateSubscriptionOperator, +) + +EXECUTION_TIMEOUT = int(os.getenv("EXECUTION_TIMEOUT", 6)) + +default_args = { + "execution_timeout": timedelta(hours=EXECUTION_TIMEOUT), + "azure_service_bus_conn_id": "azure_service_bus_default", +} + +CLIENT_ID = os.getenv("CLIENT_ID", "") +TOPIC_NAME = "sb_mgmt_topic_test" +SUBSCRIPTION_NAME = "sb_mgmt_subscription" + +with DAG( + dag_id="example_azure_service_bus_subscription", + start_date=datetime(2021, 8, 13), + schedule_interval=None, + catchup=False, + default_args=default_args, + tags=["example", "Azure service bus subscription"], +) as dag: + # [START howto_operator_create_service_bus_subscription] + create_service_bus_subscription = ASBCreateSubscriptionOperator( + task_id="create_service_bus_subscription", + topic_name=TOPIC_NAME, + subscription_name=SUBSCRIPTION_NAME, + ) + # [END howto_operator_create_service_bus_subscription] + + # [START howto_operator_update_service_bus_subscription] + update_service_bus_subscription = ASBUpdateSubscriptionOperator( + task_id="update_service_bus_subscription", + topic_name=TOPIC_NAME, + subscription_name=SUBSCRIPTION_NAME, + max_delivery_count=5, + ) + # [END howto_operator_update_service_bus_subscription] + + # [START howto_operator_receive_message_service_bus_subscription] + receive_message_service_bus_subscription = ASBReceiveSubscriptionMessageOperator( + task_id="receive_message_service_bus_subscription", + topic_name=TOPIC_NAME, + subscription_name=SUBSCRIPTION_NAME, + max_message_count=10, + ) + # [END howto_operator_receive_message_service_bus_subscription] + + # [START howto_operator_delete_service_bus_subscription] + delete_service_bus_subscription = ASBDeleteSubscriptionOperator( + task_id="delete_service_bus_subscription", + topic_name=TOPIC_NAME, + subscription_name=SUBSCRIPTION_NAME, + trigger_rule="all_done", + ) + # [END howto_operator_delete_service_bus_subscription] + + ( + create_service_bus_subscription + >> update_service_bus_subscription + >> receive_message_service_bus_subscription + >> delete_service_bus_subscription + ) diff --git a/airflow/providers/microsoft/azure/hooks/asb_admin_client.py b/airflow/providers/microsoft/azure/hooks/asb_admin_client.py index f56b596800996..afd1eb8e45d98 100644 --- a/airflow/providers/microsoft/azure/hooks/asb_admin_client.py +++ b/airflow/providers/microsoft/azure/hooks/asb_admin_client.py @@ -15,7 +15,15 @@ # specific language governing permissions and limitations # under the License. -from azure.servicebus.management import QueueProperties, ServiceBusAdministrationClient +import datetime +from typing import Optional, Union + +from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError +from azure.servicebus.management import ( + QueueProperties, + ServiceBusAdministrationClient, + SubscriptionProperties, +) from airflow.exceptions import AirflowBadRequest, AirflowException from airflow.providers.microsoft.azure.hooks.base_asb import BaseAzureServiceBusHook @@ -88,3 +96,119 @@ def delete_queue(self, queue_name: str) -> None: service_mgmt_conn.delete_queue(queue_name) except Exception as e: raise AirflowException(e) + + def create_subscription( + self, + subscription_name: str, + topic_name: str, + lock_duration: Optional[Union[datetime.timedelta, str]] = None, + requires_session: Optional[bool] = None, + default_message_time_to_live: Optional[Union[datetime.timedelta, str]] = None, + dl_on_message_expiration: Optional[bool] = None, + dl_on_filter_evaluation_exceptions: Optional[bool] = None, + max_delivery_count: Optional[int] = None, + enable_batched_operations: Optional[bool] = None, + forward_to: Optional[str] = None, + user_metadata: Optional[str] = None, + forward_dead_lettered_messages_to: Optional[str] = None, + auto_delete_on_idle: Optional[Union[datetime.timedelta, str]] = None, + ) -> SubscriptionProperties: + """ + Create a topic subscription entities under a ServiceBus Namespace. + + :param subscription_name: The subscription that will own the to-be-created rule. + :param topic_name: The topic that will own the to-be-created subscription rule. + :param lock_duration: ISO 8601 timespan duration of a peek-lock; that is, the amount of time + that the message is locked for other receivers. The maximum value for LockDuration is 5 + minutes; the default value is 1 minute. + Input value of either type ~datetime.timedelta or string in ISO 8601 duration format like + "PT300S" is accepted. + :param requires_session: A value that indicates whether the queue supports the concept of + sessions. + :param default_message_time_to_live: ISO 8601 default message timespan to live value. This is + the duration after which the message expires, starting from when the message is sent to Service + Bus. This is the default value used when TimeToLive is not set on a message itself. + Input value of either type ~datetime.timedelta or string in ISO 8601 duration format like + "PT300S" is accepted. + :param dl_on_message_expiration: A value that indicates whether this subscription + has dead letter support when a message expires. + :param dl_on_filter_evaluation_exceptions: A value that indicates whether this + subscription has dead letter support when a message expires. + :param max_delivery_count: The maximum delivery count. A message is automatically deadlettered + after this number of deliveries. Default value is 10. + :param enable_batched_operations: Value that indicates whether server-side batched operations + are enabled. + :param forward_to: The name of the recipient entity to which all the messages sent to the + subscription are forwarded to. + :param user_metadata: Metadata associated with the subscription. Maximum number of characters + is 1024. + :param forward_dead_lettered_messages_to: The name of the recipient entity to which all the + messages sent to the subscription are forwarded to. + :param auto_delete_on_idle: ISO 8601 timeSpan idle interval after which the subscription is + automatically deleted. The minimum duration is 5 minutes. + Input value of either type ~datetime.timedelta or string in ISO 8601 duration format like + "PT300S" is accepted. + """ + if subscription_name is None: + raise AirflowBadRequest("Subscription name cannot be None.") + if topic_name is None: + raise AirflowBadRequest("Topic name cannot be None.") + try: + with self.get_conn() as service_mgmt_conn: + subscription = service_mgmt_conn.create_subscription( + topic_name=topic_name, + subscription_name=subscription_name, + lock_duration=lock_duration, + requires_session=requires_session, + default_message_time_to_live=default_message_time_to_live, + dead_lettering_on_message_expiration=dl_on_message_expiration, + dead_lettering_on_filter_evaluation_exceptions=dl_on_filter_evaluation_exceptions, + max_delivery_count=max_delivery_count, + enable_batched_operations=enable_batched_operations, + forward_to=forward_to, + user_metadata=user_metadata, + forward_dead_lettered_messages_to=forward_dead_lettered_messages_to, + auto_delete_on_idle=auto_delete_on_idle, + ) + return subscription + except ResourceExistsError as e: + raise e + + def delete_subscription(self, subscription_name: str, topic_name: str) -> None: + """ + Delete a topic subscription entities under a ServiceBus Namespace + + :param subscription_name: The subscription name that will own the rule in topic + :param topic_name: The topic that will own the subscription rule. + """ + if subscription_name is None: + raise AirflowBadRequest("Subscription name cannot be None.") + if topic_name is None: + raise AirflowBadRequest("Topic name cannot be None.") + try: + with self.get_conn() as service_mgmt_conn: + service_mgmt_conn.delete_subscription(topic_name, subscription_name) + except ResourceNotFoundError as e: + raise e + + def update_subscription( + self, + subscription_name: str, + topic_name: str, + max_delivery_count: Optional[int], + dl_on_message_expiration: Optional[bool], + enable_batched_operations: Optional[bool], + ) -> None: + with self.get_conn() as service_mgmt_conn: + try: + subscription_prop = service_mgmt_conn.get_subscription(topic_name, subscription_name) + if max_delivery_count: + subscription_prop.max_delivery_count = max_delivery_count + if dl_on_message_expiration is not None: + subscription_prop.dead_lettering_on_message_expiration = dl_on_message_expiration + if enable_batched_operations is not None: + subscription_prop.enable_batched_operations = enable_batched_operations + # update by updating the properties in the model + service_mgmt_conn.update_subscription(topic_name, subscription_prop) + except ResourceNotFoundError as e: + raise e diff --git a/airflow/providers/microsoft/azure/hooks/asb_message.py b/airflow/providers/microsoft/azure/hooks/asb_message.py index b6f8e8931944c..6433bc89e42ca 100644 --- a/airflow/providers/microsoft/azure/hooks/asb_message.py +++ b/airflow/providers/microsoft/azure/hooks/asb_message.py @@ -120,3 +120,43 @@ def receive_message( receiver.complete_message(msg) except Exception as e: raise AirflowException(e) + + def receive_subscription_message( + self, + topic_name: str, + subscription_name: str, + max_message_count: Optional[int], + max_wait_time: Optional[float], + ): + """ + Receive a batch of subscription message at once. This approach is optimal if you wish + to process multiple messages simultaneously, or perform an ad-hoc receive as a single call. + + :param subscription_name: The subscription name that will own the rule in topic + :param topic_name: The topic that will own the subscription rule. + :param max_message_count: Maximum number of messages in the batch. + Actual number returned will depend on prefetch_count and incoming stream rate. + Setting to None will fully depend on the prefetch config. The default value is 1. + :param max_wait_time: Maximum time to wait in seconds for the first message to arrive. If no + messages arrive, and no timeout is specified, this call will not return until the + connection is closed. If specified, an no messages arrive within the timeout period, + an empty list will be returned. + """ + if subscription_name is None: + raise AirflowBadRequest("Subscription name cannot be None.") + if topic_name is None: + raise AirflowBadRequest("Topic name cannot be None.") + try: + with self.get_conn() as service_bus_client: + subscription_receiver = service_bus_client.get_subscription_receiver( + topic_name, subscription_name + ) + with subscription_receiver: + received_msgs = subscription_receiver.receive_messages( + max_message_count=max_message_count, max_wait_time=max_wait_time + ) + for msg in received_msgs: + self.log.info(msg) + subscription_receiver.complete_message(msg) + except Exception as e: + raise AirflowException(e) diff --git a/tests/providers/microsoft/azure/hooks/test_asb_admin_client.py b/tests/providers/microsoft/azure/hooks/test_asb_admin_client.py index 6521ed64e7699..1993ba94d8d75 100644 --- a/tests/providers/microsoft/azure/hooks/test_asb_admin_client.py +++ b/tests/providers/microsoft/azure/hooks/test_asb_admin_client.py @@ -102,3 +102,137 @@ def test_delete_queue_exception(self, mock_sb_admin_client): hook = AzureServiceBusAdminClientHook(azure_service_bus_conn_id=self.conn_id) with pytest.raises(AirflowException): hook.delete_queue(None) + + @pytest.mark.parametrize( + ( + "mock_subscription_name, mock_topic_name, mock_max_delivery ,mock_dead_letter, " + "mock_enable_batched_operations" + ), + [ + ("subscription_1", "topic_1", 10, True, True), + ("subscription_1", "topic_1", None, None, None), + ("subscription_1", "topic_1", 10, None, None), + ("subscription_1", "topic_1", None, True, None), + ("subscription_1", "topic_1", None, None, True), + ], + ) + @mock.patch('azure.servicebus.management.SubscriptionProperties') + @mock.patch( + 'airflow.providers.microsoft.azure.hooks.asb_admin_client.AzureServiceBusAdminClientHook.get_conn' + ) + def test_create_subscription( + self, + mock_sb_admin_client, + mock_subscription_properties, + mock_subscription_name, + mock_topic_name, + mock_max_delivery, + mock_dead_letter, + mock_enable_batched_operations, + ): + """Test create subscription by mocking the admin client, subscription_name , topic_name""" + mock_subscription_properties.name = "test_subscription" + mock_sb_admin_client.return_value.__enter__.return_value.create_subscription.return_value = ( + mock_subscription_properties + ) + hook = AzureServiceBusAdminClientHook(azure_service_bus_conn_id=self.conn_id) + response = hook.create_subscription( + mock_subscription_name, + mock_topic_name, + mock_max_delivery, + mock_dead_letter, + mock_enable_batched_operations, + ) + assert response == mock_subscription_properties + + @pytest.mark.parametrize( + "mock_subscription_name, mock_topic_name", + [("subscription_1", None), (None, "topic_1")], + ) + @mock.patch('airflow.providers.microsoft.azure.hooks.asb_admin_client.ServiceBusAdministrationClient') + def test_create_subscription_exception( + self, mock_sb_admin_client, mock_subscription_name, mock_topic_name + ): + """ + Test `create_subscription` functionality to raise AirflowException, + by passing subscription name and topic name as None and pytest raise Airflow Exception + """ + hook = AzureServiceBusAdminClientHook(azure_service_bus_conn_id=self.conn_id) + with pytest.raises(AirflowException): + hook.create_subscription(mock_subscription_name, mock_topic_name) + + @mock.patch( + 'airflow.providers.microsoft.azure.hooks.asb_admin_client.AzureServiceBusAdminClientHook.get_conn' + ) + def test_delete_subscription(self, mock_sb_admin_client): + """ + Test Delete subscription functionality by passing subscription name and topic name, + assert the function with values, mock the azure service bus function `delete_subscription` + """ + subscription_name = "test_subscription_name" + topic_name = "test_topic_name" + hook = AzureServiceBusAdminClientHook(azure_service_bus_conn_id=self.conn_id) + hook.delete_subscription(subscription_name, topic_name) + expected_calls = [mock.call().__enter__().delete_subscription(topic_name, subscription_name)] + mock_sb_admin_client.assert_has_calls(expected_calls) + + @pytest.mark.parametrize( + "mock_subscription_name, mock_topic_name", + [("subscription_1", None), (None, "topic_1")], + ) + @mock.patch('airflow.providers.microsoft.azure.hooks.asb_admin_client.ServiceBusAdministrationClient') + def test_delete_subscription_exception( + self, mock_sb_admin_client, mock_subscription_name, mock_topic_name + ): + """ + Test `delete_subscription` functionality to raise AirflowException, + by passing subscription name and topic name as None and pytest raise Airflow Exception + """ + hook = AzureServiceBusAdminClientHook(azure_service_bus_conn_id=self.conn_id) + with pytest.raises(AirflowException): + hook.delete_subscription(mock_subscription_name, mock_topic_name) + + @pytest.mark.parametrize( + ( + "mock_subscription_name, mock_topic_name, mock_max_delivery ,mock_dead_letter, " + "mock_enable_batched_operations" + ), + [ + ("subscription_1", "topic_1", 10, True, True), + ("subscription_1", "topic_1", None, None, None), + ("subscription_1", "topic_1", 10, None, None), + ("subscription_1", "topic_1", None, True, None), + ("subscription_1", "topic_1", None, None, True), + ], + ) + @mock.patch('azure.servicebus.management.SubscriptionProperties') + @mock.patch( + 'airflow.providers.microsoft.azure.hooks.asb_admin_client.AzureServiceBusAdminClientHook.get_conn' + ) + def test_update_subscription( + self, + mock_sb_admin_client, + mock_subscription_properties, + mock_subscription_name, + mock_topic_name, + mock_max_delivery, + mock_dead_letter, + mock_enable_batched_operations, + ): + """Test update subscription by mocking the admin client, and other details""" + mock_subscription_properties.name = mock_subscription_name + mock_sb_admin_client.return_value.__enter__.return_value.get_subscription.return_value = ( + mock_subscription_properties + ) + hook = AzureServiceBusAdminClientHook(azure_service_bus_conn_id=self.conn_id) + hook.update_subscription( + mock_subscription_name, + mock_topic_name, + mock_max_delivery, + mock_dead_letter, + mock_enable_batched_operations, + ) + expected_calls = [ + mock.call().__enter__().update_subscription(mock_topic_name, mock_subscription_properties) + ] + mock_sb_admin_client.assert_has_calls(expected_calls) diff --git a/tests/providers/microsoft/azure/hooks/test_asb_message.py b/tests/providers/microsoft/azure/hooks/test_asb_message.py index bc1d93c366009..c3e7202cf76b1 100644 --- a/tests/providers/microsoft/azure/hooks/test_asb_message.py +++ b/tests/providers/microsoft/azure/hooks/test_asb_message.py @@ -155,3 +155,46 @@ def test_receive_message_exception(self, mock_sb_client): hook = ServiceBusMessageHook(azure_service_bus_conn_id=self.conn_id) with pytest.raises(AirflowException): hook.receive_message(None) + + @mock.patch('airflow.providers.microsoft.azure.hooks.asb_message.ServiceBusMessageHook.get_conn') + def test_receive_subscription_message(self, mock_sb_client): + """ + Test `receive_subscription_message` hook function and assert the function with mock value, + mock the azure service bus `receive_message` function of subscription + """ + subscription_name = "subscription_1" + topic_name = "topic_name" + max_message_count = 10 + max_wait_time = 5 + hook = ServiceBusMessageHook(azure_service_bus_conn_id=self.conn_id) + hook.receive_subscription_message(topic_name, subscription_name, max_message_count, max_wait_time) + expected_calls = [ + mock.call() + .__enter__() + .get_subscription_receiver(subscription_name, topic_name) + .__enter__() + .receive_messages(max_message_count=max_message_count, max_wait_time=max_wait_time) + .get_queue_receiver(self.queue_name) + .__exit__() + .mock_call() + .__exit__ + ] + mock_sb_client.assert_has_calls(expected_calls) + + @pytest.mark.parametrize( + "mock_subscription_name, mock_topic_name, mock_max_count, mock_wait_time", + [("subscription_1", None, None, None), (None, "topic_1", None, None)], + ) + @mock.patch('airflow.providers.microsoft.azure.hooks.asb_message.ServiceBusMessageHook.get_conn') + def test_receive_subscription_message_exception( + self, mock_sb_client, mock_subscription_name, mock_topic_name, mock_max_count, mock_wait_time + ): + """ + Test `receive_subscription_message` hook function to raise exception + by sending the subscription and topic name as none + """ + hook = ServiceBusMessageHook(azure_service_bus_conn_id=self.conn_id) + with pytest.raises(AirflowException): + hook.receive_subscription_message( + mock_subscription_name, mock_topic_name, mock_max_count, mock_wait_time + ) diff --git a/tests/providers/microsoft/azure/operators/test_azure_service_subscription.py b/tests/providers/microsoft/azure/operators/test_azure_service_subscription.py new file mode 100644 index 0000000000000..acfb9999a7913 --- /dev/null +++ b/tests/providers/microsoft/azure/operators/test_azure_service_subscription.py @@ -0,0 +1,223 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import datetime +import unittest +from unittest import mock + +from airflow.models.dag import DAG +from airflow.providers.microsoft.azure.operators.azure_service_bus_subscription import ( + ASBCreateSubscriptionOperator, + ASBDeleteSubscriptionOperator, + ASBReceiveSubscriptionMessageOperator, + ASBUpdateSubscriptionOperator, +) + +OWNER_NAME = "airflow" +DAG_ID = "test_azure_service_bus_subscription" +TOPIC_NAME = "sb_mgmt_topic_test" +SUBSCRIPTION_NAME = "sb_mgmt_subscription" + + +class TestASBCreateSubscriptionOperator(unittest.TestCase): + def setUp(self): + args = {'owner': OWNER_NAME, 'start_date': datetime.datetime(2017, 1, 1)} + self.dag = DAG(DAG_ID, default_args=args) + + def test_init(self): + """ + Test init by creating ASBCreateSubscriptionOperator with task id, subscription name, topic name and + asserting with value + """ + asb_create_subscription = ASBCreateSubscriptionOperator( + task_id="asb_create_subscription", + topic_name=TOPIC_NAME, + subscription_name=SUBSCRIPTION_NAME, + ) + assert asb_create_subscription.task_id == "asb_create_subscription" + assert asb_create_subscription.subscription_name == SUBSCRIPTION_NAME + assert asb_create_subscription.topic_name == TOPIC_NAME + + @mock.patch( + "airflow.providers.microsoft.azure.hooks.asb_admin_client.AzureServiceBusAdminClientHook.get_conn" + ) + def test_create_subscription(self, mock_get_conn): + """ + Test ASBCreateSubscriptionOperator passed with the subscription name, topic name + mocking the connection details, hook create_subscription function + """ + asb_create_subscription = ASBCreateSubscriptionOperator( + task_id="create_service_bus_subscription", + topic_name=TOPIC_NAME, + subscription_name=SUBSCRIPTION_NAME, + dag=self.dag, + ) + asb_create_subscription.execute(None) + mock_get_conn.return_value.__enter__.return_value.create_subscription.assert_called_once_with( + topic_name=TOPIC_NAME, + subscription_name=SUBSCRIPTION_NAME, + lock_duration=None, + requires_session=None, + default_message_time_to_live=None, + dead_lettering_on_message_expiration=True, + dead_lettering_on_filter_evaluation_exceptions=None, + max_delivery_count=10, + enable_batched_operations=True, + forward_to=None, + user_metadata=None, + forward_dead_lettered_messages_to=None, + auto_delete_on_idle=None, + ) + + +class TestASBDeleteSubscriptionOperator(unittest.TestCase): + def setUp(self): + args = {'owner': OWNER_NAME, 'start_date': datetime.datetime(2017, 1, 1)} + self.dag = DAG(DAG_ID, default_args=args) + + def test_init(self): + """ + Test init by creating ASBDeleteSubscriptionOperator with task id, subscription name, topic name and + asserting with values + """ + asb_delete_subscription_operator = ASBDeleteSubscriptionOperator( + task_id="asb_delete_subscription", + topic_name=TOPIC_NAME, + subscription_name=SUBSCRIPTION_NAME, + dag=self.dag, + ) + assert asb_delete_subscription_operator.task_id == "asb_delete_subscription" + assert asb_delete_subscription_operator.topic_name == TOPIC_NAME + assert asb_delete_subscription_operator.subscription_name == SUBSCRIPTION_NAME + + @mock.patch( + "airflow.providers.microsoft.azure.hooks.asb_admin_client.AzureServiceBusAdminClientHook.get_conn" + ) + def test_delete_subscription(self, mock_get_conn): + """ + Test ASBDeleteSubscriptionOperator by mocking subscription name, topic name and + connection and hook delete_subscription + """ + asb_delete_subscription_operator = ASBDeleteSubscriptionOperator( + task_id="asb_delete_subscription", + topic_name=TOPIC_NAME, + subscription_name=SUBSCRIPTION_NAME, + dag=self.dag, + ) + asb_delete_subscription_operator.execute(None) + mock_get_conn.return_value.__enter__.return_value.delete_subscription.assert_called_once_with( + TOPIC_NAME, SUBSCRIPTION_NAME + ) + + +class TestASBUpdateSubscriptionOperator(unittest.TestCase): + def setUp(self): + args = {'owner': OWNER_NAME, 'start_date': datetime.datetime(2017, 1, 1)} + self.dag = DAG(DAG_ID, default_args=args) + + def test_init(self): + """ + Test init by creating ASBUpdateSubscriptionOperator with task id, subscription name, topic name and + asserting with values + """ + asb_update_subscription_operator = ASBUpdateSubscriptionOperator( + task_id="asb_update_subscription", + topic_name=TOPIC_NAME, + subscription_name=SUBSCRIPTION_NAME, + max_delivery_count=10, + dag=self.dag, + ) + assert asb_update_subscription_operator.task_id == "asb_update_subscription" + assert asb_update_subscription_operator.topic_name == TOPIC_NAME + assert asb_update_subscription_operator.subscription_name == SUBSCRIPTION_NAME + assert asb_update_subscription_operator.max_delivery_count == 10 + + @mock.patch('azure.servicebus.management.SubscriptionProperties') + @mock.patch( + "airflow.providers.microsoft.azure.hooks.asb_admin_client.AzureServiceBusAdminClientHook.get_conn" + ) + def test_update_subscription(self, mock_get_conn, mock_subscription_properties): + """ + Test ASBUpdateSubscriptionOperator passed with the subscription name, topic name + mocking the connection details, hook update_subscription function + """ + mock_subscription_properties.name = SUBSCRIPTION_NAME + mock_subscription_properties.max_delivery_count = 20 + mock_get_conn.return_value.__enter__.return_value.get_subscription.return_value = ( + mock_subscription_properties + ) + asb_update_subscription = ASBUpdateSubscriptionOperator( + task_id="asb_update_subscription", + topic_name=TOPIC_NAME, + subscription_name=SUBSCRIPTION_NAME, + max_delivery_count=20, + dag=self.dag, + ) + asb_update_subscription.execute(None) + expected_calls = [ + mock.call().__enter__().update_subscription(TOPIC_NAME, mock_subscription_properties) + ] + mock_get_conn.assert_has_calls(expected_calls) + + +class TestASBSubscriptionReceiveMessageOperator(unittest.TestCase): + def setUp(self): + args = {'owner': OWNER_NAME, 'start_date': datetime.datetime(2017, 1, 1)} + self.dag = DAG(DAG_ID, default_args=args) + + def test_init(self): + """ + Test init by creating AzureServiceBusReceiveMessageOperator with task id, queue_name, message, + batch and asserting with values + """ + + asb_subscription_receive_message = ASBReceiveSubscriptionMessageOperator( + task_id="asb_subscription_receive_message", + topic_name=TOPIC_NAME, + subscription_name=SUBSCRIPTION_NAME, + max_message_count=10, + ) + assert asb_subscription_receive_message.task_id == "asb_subscription_receive_message" + assert asb_subscription_receive_message.topic_name == TOPIC_NAME + assert asb_subscription_receive_message.subscription_name == SUBSCRIPTION_NAME + assert asb_subscription_receive_message.max_message_count == 10 + + @mock.patch("airflow.providers.microsoft.azure.hooks.asb_message.ServiceBusMessageHook.get_conn") + def test_receive_message_queue(self, mock_get_conn): + """ + Test ASBReceiveSubscriptionMessageOperator by mock connection, values + and the service bus receive message + """ + asb_subscription_receive_message = ASBReceiveSubscriptionMessageOperator( + task_id="asb_subscription_receive_message", + topic_name=TOPIC_NAME, + subscription_name=SUBSCRIPTION_NAME, + max_message_count=10, + ) + asb_subscription_receive_message.execute(None) + expected_calls = [ + mock.call() + .__enter__() + .get_subscription_receiver(SUBSCRIPTION_NAME, TOPIC_NAME) + .__enter__() + .receive_messages(max_message_count=10, max_wait_time=5) + .get_queue_receiver(self.queue_name) + .__exit__() + .mock_call() + .__exit__ + ] + mock_get_conn.assert_has_calls(expected_calls) From a2e23b5e5057d887c66672a0abe4ff9dd53bbd67 Mon Sep 17 00:00:00 2001 From: bharanidharan14 Date: Wed, 1 Jun 2022 11:32:38 +0530 Subject: [PATCH 10/13] Add Operator's --- .../azure_service_bus_subscription.py | 268 ++++++++++++++++++ 1 file changed, 268 insertions(+) create mode 100644 airflow/providers/microsoft/azure/operators/azure_service_bus_subscription.py diff --git a/airflow/providers/microsoft/azure/operators/azure_service_bus_subscription.py b/airflow/providers/microsoft/azure/operators/azure_service_bus_subscription.py new file mode 100644 index 0000000000000..602d5a45da7e6 --- /dev/null +++ b/airflow/providers/microsoft/azure/operators/azure_service_bus_subscription.py @@ -0,0 +1,268 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import datetime +from typing import TYPE_CHECKING, Optional, Sequence, Union + +from airflow.models import BaseOperator +from airflow.providers.microsoft.azure.hooks.asb_admin_client import AzureServiceBusAdminClientHook +from airflow.providers.microsoft.azure.hooks.asb_message import ServiceBusMessageHook + +if TYPE_CHECKING: + from airflow.utils.context import Context + + +class ASBCreateSubscriptionOperator(BaseOperator): + """ + Creates an Azure ServiceBus Topic Subscription under a ServiceBus Namespace + by using ServiceBusAdministrationClient + + :param topic_name: The topic that will own the to-be-created subscription. + :param subscription_name: Name of the subscription that need to be created + :param lock_duration: ISO 8601 timespan duration of a peek-lock; that is, the amount of time + that the message is locked for other receivers. The maximum value for LockDuration is 5 + minutes; the default value is 1 minute. + Input value of either type ~datetime.timedelta or string in ISO 8601 duration format like + "PT300S" is accepted. + :param requires_session: A value that indicates whether the queue supports the concept of + sessions. + :param default_message_time_to_live: ISO 8601 default message timespan to live value. This is + the duration after which the message expires, starting from when the message is sent to Service + Bus. This is the default value used when TimeToLive is not set on a message itself. + Input value of either type ~datetime.timedelta or string in ISO 8601 duration format like + "PT300S" is accepted. + :param dead_lettering_on_message_expiration: A value that indicates whether this subscription + has dead letter support when a message expires. + :param dead_lettering_on_filter_evaluation_exceptions: A value that indicates whether this + subscription has dead letter support when a message expires. + :param max_delivery_count: The maximum delivery count. A message is automatically deadlettered + after this number of deliveries. Default value is 10. + :param enable_batched_operations: Value that indicates whether server-side batched operations + are enabled. + :param forward_to: The name of the recipient entity to which all the messages sent to the + subscription are forwarded to. + :param user_metadata: Metadata associated with the subscription. Maximum number of characters + is 1024. + :param forward_dead_lettered_messages_to: The name of the recipient entity to which all the + messages sent to the subscription are forwarded to. + :param auto_delete_on_idle: ISO 8601 timeSpan idle interval after which the subscription is + automatically deleted. The minimum duration is 5 minutes. + Input value of either type ~datetime.timedelta or string in ISO 8601 duration format like + "PT300S" is accepted. + :param azure_service_bus_conn_id: Reference to the + :ref:`Azure Service Bus connection`. + """ + + template_fields: Sequence[str] = ("topic_name", "subscription_name") + ui_color = "#e4f0e8" + + def __init__( + self, + *, + topic_name: str, + subscription_name: str, + azure_service_bus_conn_id: str = 'azure_service_bus_default', + lock_duration: Optional[Union[datetime.timedelta, str]] = None, + requires_session: Optional[bool] = None, + default_message_time_to_live: Optional[Union[datetime.timedelta, str]] = None, + dead_lettering_on_message_expiration: Optional[bool] = True, + dead_lettering_on_filter_evaluation_exceptions: Optional[bool] = None, + max_delivery_count: Optional[int] = 10, + enable_batched_operations: Optional[bool] = True, + forward_to: Optional[str] = None, + user_metadata: Optional[str] = None, + forward_dead_lettered_messages_to: Optional[str] = None, + auto_delete_on_idle: Optional[Union[datetime.timedelta, str]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.topic_name = topic_name + self.subscription_name = subscription_name + self.lock_duration = lock_duration + self.requires_session = requires_session + self.default_message_time_to_live = default_message_time_to_live + self.dead_lettering_on_message_expiration = dead_lettering_on_message_expiration + self.dead_lettering_on_filter_evaluation_exceptions = dead_lettering_on_filter_evaluation_exceptions + self.max_delivery_count = max_delivery_count + self.enable_batched_operations = enable_batched_operations + self.forward_to = forward_to + self.user_metadata = user_metadata + self.forward_dead_lettered_messages_to = forward_dead_lettered_messages_to + self.auto_delete_on_idle = auto_delete_on_idle + self.azure_service_bus_conn_id = azure_service_bus_conn_id + + def execute(self, context: "Context") -> None: + """Creates Subscription in Service Bus namespace, by connecting to Service Bus Admin client""" + # Create the hook + hook = AzureServiceBusAdminClientHook(azure_service_bus_conn_id=self.azure_service_bus_conn_id) + + # create subscription with name + subscription = hook.create_subscription( + self.subscription_name, + self.topic_name, + self.lock_duration, + self.requires_session, + self.default_message_time_to_live, + self.dead_lettering_on_message_expiration, + self.dead_lettering_on_filter_evaluation_exceptions, + self.max_delivery_count, + self.enable_batched_operations, + self.forward_to, + self.user_metadata, + self.forward_dead_lettered_messages_to, + self.auto_delete_on_idle, + ) + self.log.info("Created Queue %s", subscription.name) + self.log.info(subscription) + + +class ASBUpdateSubscriptionOperator(BaseOperator): + """ + Update an Azure ServiceBus Topic Subscription under a ServiceBus Namespace + by using ServiceBusAdministrationClient + + :param topic_name: The topic that will own the to-be-created subscription. + :param subscription_name: Name of the subscription that need to be created. + :param max_delivery_count: The maximum delivery count. A message is automatically dead lettered + after this number of deliveries. Default value is 10. + :param dead_lettering_on_message_expiration: A value that indicates whether this subscription + has dead letter support when a message expires. + :param enable_batched_operations: Value that indicates whether server-side batched operations are enabled. + :param azure_service_bus_conn_id: Reference to the + :ref:`Azure Service Bus connection`. + """ + + template_fields: Sequence[str] = ("topic_name", "subscription_name") + ui_color = "#e4f0e8" + + def __init__( + self, + *, + topic_name: str, + subscription_name: str, + max_delivery_count: Optional[int] = None, + dead_lettering_on_message_expiration: Optional[bool] = None, + enable_batched_operations: Optional[bool] = None, + azure_service_bus_conn_id: str = 'azure_service_bus_default', + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.topic_name = topic_name + self.subscription_name = subscription_name + self.max_delivery_count = max_delivery_count + self.dead_lettering_on_message_expiration = dead_lettering_on_message_expiration + self.enable_batched_operations = enable_batched_operations + self.azure_service_bus_conn_id = azure_service_bus_conn_id + + def execute(self, context: "Context") -> None: + """Updates Subscription properties, by connecting to Service Bus Admin client""" + # Create the hook + hook = AzureServiceBusAdminClientHook(azure_service_bus_conn_id=self.azure_service_bus_conn_id) + + # Update subscription with name + hook.update_subscription( + self.subscription_name, + self.topic_name, + self.max_delivery_count, + self.dead_lettering_on_message_expiration, + self.enable_batched_operations, + ) + + +class ASBReceiveSubscriptionMessageOperator(BaseOperator): + """ + Receive a messages for the specific subscription under the topic. + + :param subscription_name: The subscription name that will own the rule in topic + :param topic_name: The topic that will own the subscription rule. + :param max_message_count: Maximum number of messages in the batch. + Actual number returned will depend on prefetch_count and incoming stream rate. + Setting to None will fully depend on the prefetch config. The default value is 1. + :param max_wait_time: Maximum time to wait in seconds for the first message to arrive. If no + messages arrive, and no timeout is specified, this call will not return until the + connection is closed. If specified, an no messages arrive within the timeout period, + an empty list will be returned. + :param azure_service_bus_conn_id: Reference to the + :ref:`Azure Service Bus connection `. + """ + + template_fields: Sequence[str] = ("topic_name", "subscription_name") + ui_color = "#e4f0e8" + + def __init__( + self, + *, + topic_name: str, + subscription_name: str, + max_message_count: Optional[int] = 1, + max_wait_time: Optional[float] = 5, + azure_service_bus_conn_id: str = 'azure_service_bus_default', + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.topic_name = topic_name + self.subscription_name = subscription_name + self.max_message_count = max_message_count + self.max_wait_time = max_wait_time + self.azure_service_bus_conn_id = azure_service_bus_conn_id + + def execute(self, context: "Context") -> None: + """ + Receive Message in specific queue in Service Bus namespace, + by connecting to Service Bus client + """ + # Create the hook + hook = ServiceBusMessageHook(azure_service_bus_conn_id=self.azure_service_bus_conn_id) + + # Receive message + hook.receive_subscription_message( + self.topic_name, self.subscription_name, self.max_message_count, self.max_wait_time + ) + + +class ASBDeleteSubscriptionOperator(BaseOperator): + """ + Deletes the topic subscription in the Azure ServiceBus namespace + + :param topic_name: The topic that will own the to-be-created subscription. + :param subscription_name: Name of the subscription that need to be created + :param azure_service_bus_conn_id: Reference to the + :ref:`Azure Service Bus connection `. + """ + + template_fields: Sequence[str] = ("topic_name", "subscription_name") + ui_color = "#e4f0e8" + + def __init__( + self, + *, + topic_name: str, + subscription_name: str, + azure_service_bus_conn_id: str = 'azure_service_bus_default', + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.topic_name = topic_name + self.subscription_name = subscription_name + self.azure_service_bus_conn_id = azure_service_bus_conn_id + + def execute(self, context: "Context") -> None: + """Delete topic subscription in Service Bus namespace, by connecting to Service Bus Admin client""" + # Create the hook + hook = AzureServiceBusAdminClientHook(azure_service_bus_conn_id=self.azure_service_bus_conn_id) + + # delete subscription with name + hook.delete_subscription(self.subscription_name, self.topic_name) From 87926a944d188b85ef29e7acefd61d4463e8cadb Mon Sep 17 00:00:00 2001 From: bharanidharan14 Date: Wed, 1 Jun 2022 13:54:31 +0530 Subject: [PATCH 11/13] Add azure_service_bus_subscription operator in provider.yml file --- airflow/providers/microsoft/azure/provider.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/airflow/providers/microsoft/azure/provider.yaml b/airflow/providers/microsoft/azure/provider.yaml index 805fed9ab973c..357795aa75c14 100644 --- a/airflow/providers/microsoft/azure/provider.yaml +++ b/airflow/providers/microsoft/azure/provider.yaml @@ -123,6 +123,7 @@ operators: - integration-name: Microsoft Azure Service Bus python-modules: - airflow.providers.microsoft.azure.operators.azure_service_bus_queue + - airflow.providers.microsoft.azure.operators.azure_service_bus_subscription sensors: - integration-name: Microsoft Azure Cosmos DB From 9b698e9968c94b919c277b09104479c69cac7329 Mon Sep 17 00:00:00 2001 From: bharanidharan14 Date: Tue, 31 May 2022 14:07:58 +0530 Subject: [PATCH 12/13] Implement Azure Service Bus Queue Operator's Implement Azure Service Bus Queue Operator's - Test case - Doc string - Removed airflow exception and added proper error type - Removed try/expect block - Renamed the azure service bus queue operator test file Implement Azure Service Bus Queue Operator's Add Azure Service Bus Queue Operator's - Create Queue Operator - Send message queue Operator - Receive message queue Operator - Delete queue Operator - example DAG - Added hooks and connection type in - provider yaml file - Added unit Test case --- .../example_dags/example_service_bus_queue.py | 105 +++++++++ .../microsoft/azure/hooks/asb_admin_client.py | 83 ++++++++ .../microsoft/azure/hooks/asb_message.py | 109 ++++++++++ .../microsoft/azure/hooks/base_asb.py | 73 +++++++ .../operators/azure_service_bus_queue.py | 182 ++++++++++++++++ .../providers/microsoft/azure/provider.yaml | 15 ++ .../connections/azure_service_bus.rst | 64 ++++++ docs/integration-logos/azure/Service-Bus.svg | 1 + setup.py | 1 + .../azure/hooks/test_asb_admin_client.py | 96 +++++++++ .../microsoft/azure/hooks/test_asb_message.py | 164 +++++++++++++++ .../operators/test_azure_service_bus_queue.py | 199 ++++++++++++++++++ 12 files changed, 1092 insertions(+) create mode 100644 airflow/providers/microsoft/azure/example_dags/example_service_bus_queue.py create mode 100644 airflow/providers/microsoft/azure/hooks/asb_admin_client.py create mode 100644 airflow/providers/microsoft/azure/hooks/asb_message.py create mode 100644 airflow/providers/microsoft/azure/hooks/base_asb.py create mode 100644 airflow/providers/microsoft/azure/operators/azure_service_bus_queue.py create mode 100644 docs/apache-airflow-providers-microsoft-azure/connections/azure_service_bus.rst create mode 100644 docs/integration-logos/azure/Service-Bus.svg create mode 100644 tests/providers/microsoft/azure/hooks/test_asb_admin_client.py create mode 100644 tests/providers/microsoft/azure/hooks/test_asb_message.py create mode 100644 tests/providers/microsoft/azure/operators/test_azure_service_bus_queue.py diff --git a/airflow/providers/microsoft/azure/example_dags/example_service_bus_queue.py b/airflow/providers/microsoft/azure/example_dags/example_service_bus_queue.py new file mode 100644 index 0000000000000..da689b82e1628 --- /dev/null +++ b/airflow/providers/microsoft/azure/example_dags/example_service_bus_queue.py @@ -0,0 +1,105 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import os +from datetime import datetime, timedelta + +from airflow import DAG +from airflow.providers.microsoft.azure.operators.azure_service_bus_queue import ( + AzureServiceBusCreateQueueOperator, + AzureServiceBusDeleteQueueOperator, + AzureServiceBusReceiveMessageOperator, + AzureServiceBusSendMessageOperator, +) + +EXECUTION_TIMEOUT = int(os.getenv("EXECUTION_TIMEOUT", 6)) + +default_args = { + "execution_timeout": timedelta(hours=EXECUTION_TIMEOUT), + "azure_service_bus_conn_id": "azure_service_bus_default", +} + +CLIENT_ID = os.getenv("CLIENT_ID", "") +QUEUE_NAME = "sb_mgmt_queue_test" +MESSAGE = "Test Message" +MESSAGE_LIST = [MESSAGE + " " + str(n) for n in range(0, 10)] + +with DAG( + dag_id="example_azure_service_bus_queue", + start_date=datetime(2021, 8, 13), + schedule_interval=None, + catchup=False, + default_args=default_args, + tags=["example", "Azure service bus"], +) as dag: + # [START howto_operator_create_service_bus_queue] + create_service_bus_queue = AzureServiceBusCreateQueueOperator( + task_id="create_service_bus_queue", + queue_name=QUEUE_NAME, + ) + # [END howto_operator_create_service_bus_queue] + + # [START howto_operator_send_message_to_service_bus_queue] + send_message_to_service_bus_queue = AzureServiceBusSendMessageOperator( + task_id="send_message_to_service_bus_queue", + message=MESSAGE, + queue_name=QUEUE_NAME, + batch=False, + ) + # [END howto_operator_send_message_to_service_bus_queue] + + # [START howto_operator_send_list_message_to_service_bus_queue] + send_list_message_to_service_bus_queue = AzureServiceBusSendMessageOperator( + task_id="send_list_message_to_service_bus_queue", + message=MESSAGE_LIST, + queue_name=QUEUE_NAME, + batch=False, + ) + # [END howto_operator_send_list_message_to_service_bus_queue] + + # [START howto_operator_send_batch_message_to_service_bus_queue] + send_batch_message_to_service_bus_queue = AzureServiceBusSendMessageOperator( + task_id="send_batch_message_to_service_bus_queue", + message=MESSAGE_LIST, + queue_name=QUEUE_NAME, + batch=True, + ) + # [END howto_operator_send_batch_message_to_service_bus_queue] + + # [START howto_operator_receive_message_service_bus_queue] + receive_message_service_bus_queue = AzureServiceBusReceiveMessageOperator( + task_id="receive_message_service_bus_queue", + queue_name=QUEUE_NAME, + max_message_count=20, + max_wait_time=5, + ) + # [END howto_operator_receive_message_service_bus_queue] + + # [START howto_operator_delete_service_bus_queue] + delete_service_bus_queue = AzureServiceBusDeleteQueueOperator( + task_id="delete_service_bus_queue", queue_name=QUEUE_NAME, trigger_rule="all_done" + ) + # [END howto_operator_delete_service_bus_queue] + + ( + create_service_bus_queue + >> send_message_to_service_bus_queue + >> send_list_message_to_service_bus_queue + >> send_batch_message_to_service_bus_queue + >> receive_message_service_bus_queue + >> delete_service_bus_queue + ) diff --git a/airflow/providers/microsoft/azure/hooks/asb_admin_client.py b/airflow/providers/microsoft/azure/hooks/asb_admin_client.py new file mode 100644 index 0000000000000..74c6d7e78b249 --- /dev/null +++ b/airflow/providers/microsoft/azure/hooks/asb_admin_client.py @@ -0,0 +1,83 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from azure.servicebus.management import QueueProperties, ServiceBusAdministrationClient + +from airflow.providers.microsoft.azure.hooks.base_asb import BaseAzureServiceBusHook + + +class AzureServiceBusAdminClientHook(BaseAzureServiceBusHook): + """ + Interacts with Azure ServiceBus management client + and Use this client to create, update, list, and delete resources of a ServiceBus namespace. + it uses the same azure service bus client connection inherits from the base class + """ + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + def get_conn(self) -> ServiceBusAdministrationClient: + """Create and returns ServiceBusAdministration by using the connection string in connection details""" + conn = self.get_connection(self.conn_id) + extras = conn.extra_dejson + + self.connection_string = str( + extras.get('connection_string') or extras.get('extra__azure_service_bus__connection_string') + ) + return ServiceBusAdministrationClient.from_connection_string(self.connection_string) + + def create_queue( + self, + queue_name: str, + max_delivery_count: int = 10, + dead_lettering_on_message_expiration: bool = True, + enable_batched_operations: bool = True, + ) -> QueueProperties: + """ + Create Queue by connecting to service Bus Admin client return the QueueProperties + + :param queue_name: The name of the queue or a QueueProperties with name. + :param max_delivery_count: The maximum delivery count. A message is automatically + dead lettered after this number of deliveries. Default value is 10.. + :param dead_lettering_on_message_expiration: A value that indicates whether this subscription has + dead letter support when a message expires. + :param enable_batched_operations: Value that indicates whether server-side batched + operations are enabled. + """ + if queue_name is None: + raise ValueError("Queue name cannot be None.") + + with self.get_conn() as service_mgmt_conn: + queue = service_mgmt_conn.create_queue( + queue_name, + max_delivery_count=max_delivery_count, + dead_lettering_on_message_expiration=dead_lettering_on_message_expiration, + enable_batched_operations=enable_batched_operations, + ) + return queue + + def delete_queue(self, queue_name: str) -> None: + """ + Delete the queue by queue_name in service bus namespace + + :param queue_name: The name of the queue or a QueueProperties with name. + """ + if queue_name is None: + raise ValueError("Queue name cannot be None.") + + with self.get_conn() as service_mgmt_conn: + service_mgmt_conn.delete_queue(queue_name) diff --git a/airflow/providers/microsoft/azure/hooks/asb_message.py b/airflow/providers/microsoft/azure/hooks/asb_message.py new file mode 100644 index 0000000000000..9ab59f08fc583 --- /dev/null +++ b/airflow/providers/microsoft/azure/hooks/asb_message.py @@ -0,0 +1,109 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import List, Optional, Union + +from azure.servicebus import ServiceBusClient, ServiceBusMessage, ServiceBusSender + +from airflow.providers.microsoft.azure.hooks.base_asb import BaseAzureServiceBusHook + + +class ServiceBusMessageHook(BaseAzureServiceBusHook): + """ + Interacts with ServiceBusClient and acts as a high level interface + for getting ServiceBusSender and ServiceBusReceiver. + """ + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + def get_conn(self) -> ServiceBusClient: + """Create and returns ServiceBusClient by using the connection string in connection details""" + conn = self.get_connection(self.conn_id) + extras = conn.extra_dejson + + self.connection_string = str( + extras.get('connection_string') or extras.get('extra__azure_service_bus__connection_string') + ) + + return ServiceBusClient.from_connection_string(conn_str=self.connection_string, logging_enable=True) + + def send_message( + self, queue_name: str, messages: Union[str, List[str]], batch_message_flag: bool = False + ): + """ + By using ServiceBusClient Send message(s) to a Service Bus Queue. By using + batch_message_flag it enables and send message as batch message + + :param queue_name: The name of the queue or a QueueProperties with name. + :param messages: Message which needs to be sent to the queue. It can be string or list of string. + :param batch_message_flag: bool flag, can be set to True if message needs to be sent as batch message. + """ + if queue_name is None: + raise TypeError("Queue name cannot be None.") + if not messages: + raise ValueError("Messages list cannot be empty.") + with self.get_conn() as service_bus_client, service_bus_client.get_queue_sender( + queue_name=queue_name + ) as sender: + with sender: + if isinstance(messages, str): + if not batch_message_flag: + msg = ServiceBusMessage(messages) + sender.send_messages(msg) + else: + self.send_batch_message(sender, [messages]) + else: + if not batch_message_flag: + self.send_list_messages(sender, messages) + else: + self.send_batch_message(sender, messages) + + @staticmethod + def send_list_messages(sender: ServiceBusSender, messages: List[str]): + list_messages = [ServiceBusMessage(message) for message in messages] + sender.send_messages(list_messages) # type: ignore[arg-type] + + @staticmethod + def send_batch_message(sender: ServiceBusSender, messages: List[str]): + batch_message = sender.create_message_batch() + for message in messages: + batch_message.add_message(ServiceBusMessage(message)) + sender.send_messages(batch_message) + + def receive_message( + self, queue_name, max_message_count: Optional[int] = 1, max_wait_time: Optional[float] = None + ): + """ + Receive a batch of messages at once in a specified Queue name + + :param queue_name: The name of the queue name or a QueueProperties with name. + :param max_message_count: Maximum number of messages in the batch. + :param max_wait_time: Maximum time to wait in seconds for the first message to arrive. + """ + if queue_name is None: + raise ValueError("Queue name cannot be None.") + + with self.get_conn() as service_bus_client, service_bus_client.get_queue_receiver( + queue_name=queue_name + ) as receiver: + with receiver: + received_msgs = receiver.receive_messages( + max_message_count=max_message_count, max_wait_time=max_wait_time + ) + for msg in received_msgs: + self.log.info(msg) + receiver.complete_message(msg) diff --git a/airflow/providers/microsoft/azure/hooks/base_asb.py b/airflow/providers/microsoft/azure/hooks/base_asb.py new file mode 100644 index 0000000000000..8adbf577e3bf2 --- /dev/null +++ b/airflow/providers/microsoft/azure/hooks/base_asb.py @@ -0,0 +1,73 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Any, Dict, Optional + +from airflow.hooks.base import BaseHook + + +class BaseAzureServiceBusHook(BaseHook): + """ + BaseAzureServiceBusHook class to create session and create connection. Client ID and + Secrete IDs are optional. + + :param azure_service_bus_conn_id: Reference to the + :ref:`Azure Service Bus connection`. + """ + + conn_name_attr = 'azure_service_bus_conn_id' + default_conn_name = 'azure_service_bus_default' + conn_type = 'azure_service_bus' + hook_name = 'Azure ServiceBus' + + @staticmethod + def get_connection_form_widgets() -> Dict[str, Any]: + """Returns connection widgets to add to connection form""" + from flask_appbuilder.fieldwidgets import BS3TextFieldWidget + from flask_babel import lazy_gettext + from wtforms import StringField + + return { + "extra__azure_service_bus__connection_string": StringField( + lazy_gettext('Azure Service Bus Connection String'), widget=BS3TextFieldWidget() + ), + } + + @staticmethod + def get_ui_field_behaviour() -> Dict[str, Any]: + """Returns custom field behaviour""" + return { + "hidden_fields": ['schema', 'port', 'host', 'extra'], + "relabeling": { + 'login': 'Client ID', + 'password': 'Secret', + }, + "placeholders": { + 'login': 'Client ID (Optional)', + 'password': 'Client Secret (Optional)', + 'extra__azure_service_bus__connection_string': 'Azure Service Bus Connection String', + }, + } + + def __init__(self, azure_service_bus_conn_id: str = default_conn_name) -> None: + super().__init__() + self.conn_id = azure_service_bus_conn_id + self._conn = None + self.connection_string: Optional[str] = None + + def get_conn(self): + return None diff --git a/airflow/providers/microsoft/azure/operators/azure_service_bus_queue.py b/airflow/providers/microsoft/azure/operators/azure_service_bus_queue.py new file mode 100644 index 0000000000000..59fdf5d5f5930 --- /dev/null +++ b/airflow/providers/microsoft/azure/operators/azure_service_bus_queue.py @@ -0,0 +1,182 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import TYPE_CHECKING, List, Optional, Sequence, Union + +from airflow.models import BaseOperator +from airflow.providers.microsoft.azure.hooks.asb_admin_client import AzureServiceBusAdminClientHook +from airflow.providers.microsoft.azure.hooks.asb_message import ServiceBusMessageHook + +if TYPE_CHECKING: + from airflow.utils.context import Context + + +class AzureServiceBusCreateQueueOperator(BaseOperator): + """ + Creates a Azure ServiceBus queue under a ServiceBus Namespace by using ServiceBusAdministrationClient + + :param queue_name: The name of the queue. should be unique. + :param azure_service_bus_conn_id: Reference to the + :ref:`Azure Service Bus connection`. + """ + + template_fields: Sequence[str] = ("queue_name",) + ui_color = "#e4f0e8" + + def __init__( + self, + *, + queue_name: str, + max_delivery_count: int = 10, + dead_lettering_on_message_expiration: bool = True, + enable_batched_operations: bool = True, + azure_service_bus_conn_id: str = 'azure_service_bus_default', + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.queue_name = queue_name + self.max_delivery_count = max_delivery_count + self.dead_lettering_on_message_expiration = dead_lettering_on_message_expiration + self.enable_batched_operations = enable_batched_operations + self.azure_service_bus_conn_id = azure_service_bus_conn_id + + def execute(self, context: "Context") -> None: + """Creates Queue in Azure Service Bus namespace, by connecting to Service Bus Admin client in hook""" + hook = AzureServiceBusAdminClientHook(azure_service_bus_conn_id=self.azure_service_bus_conn_id) + + # create queue with name + queue = hook.create_queue( + self.queue_name, + self.max_delivery_count, + self.dead_lettering_on_message_expiration, + self.enable_batched_operations, + ) + self.log.info("Created Queue ", queue) + + +class AzureServiceBusSendMessageOperator(BaseOperator): + """ + Send Message or batch message to the service bus queue + + :param message: Message which needs to be sent to the queue. It can be string or list of string. + :param batch: Its boolean flag by default it is set to False, if the message needs to be sent + as batch message it can be set to True. + :param azure_service_bus_conn_id: Reference to the + :ref: `Azure Service Bus connection`. + """ + + template_fields: Sequence[str] = ("queue_name",) + ui_color = "#e4f0e8" + + def __init__( + self, + *, + queue_name: str, + message: Union[str, List[str]], + batch: bool = False, + azure_service_bus_conn_id: str = 'azure_service_bus_default', + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.queue_name = queue_name + self.batch = batch + self.message = message + self.azure_service_bus_conn_id = azure_service_bus_conn_id + + def execute(self, context: "Context") -> None: + """ + Sends Message to the specific queue in Service Bus namespace, by + connecting to Service Bus client + """ + # Create the hook + hook = ServiceBusMessageHook(azure_service_bus_conn_id=self.azure_service_bus_conn_id) + + # send message + hook.send_message(self.queue_name, self.message, self.batch) + + +class AzureServiceBusReceiveMessageOperator(BaseOperator): + """ + Receive a batch of messages at once in a specified Queue name + + :param queue_name: The name of the queue in Service Bus namespace. + :param azure_service_bus_conn_id: Reference to the + :ref: `Azure Service Bus connection `. + """ + + template_fields: Sequence[str] = ("queue_name",) + ui_color = "#e4f0e8" + + def __init__( + self, + *, + queue_name: str, + azure_service_bus_conn_id: str = 'azure_service_bus_default', + max_message_count: Optional[int] = 10, + max_wait_time: Optional[float] = 5, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.queue_name = queue_name + self.azure_service_bus_conn_id = azure_service_bus_conn_id + self.max_message_count = max_message_count + self.max_wait_time = max_wait_time + + def execute(self, context: "Context") -> None: + """ + Receive Message in specific queue in Service Bus namespace, + by connecting to Service Bus client + """ + # Create the hook + hook = ServiceBusMessageHook(azure_service_bus_conn_id=self.azure_service_bus_conn_id) + + # Receive message + hook.receive_message( + self.queue_name, max_message_count=self.max_message_count, max_wait_time=self.max_wait_time + ) + + +class AzureServiceBusDeleteQueueOperator(BaseOperator): + """ + Deletes the Queue in the Azure ServiceBus namespace + + :param queue_name: The name of the queue in Service Bus namespace. + :param azure_service_bus_conn_id: Reference to the + :ref: `Azure Service Bus connection `. + """ + + template_fields: Sequence[str] = ("queue_name",) + ui_color = "#e4f0e8" + + def __init__( + self, + *, + queue_name: str, + azure_service_bus_conn_id: str = 'azure_service_bus_default', + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.queue_name = queue_name + self.azure_service_bus_conn_id = azure_service_bus_conn_id + + def execute(self, context: "Context") -> None: + """Delete Queue in Service Bus namespace, by connecting to Service Bus Admin client""" + # Create the hook + hook = AzureServiceBusAdminClientHook(azure_service_bus_conn_id=self.azure_service_bus_conn_id) + + # delete queue with name + hook.delete_queue(self.queue_name) diff --git a/airflow/providers/microsoft/azure/provider.yaml b/airflow/providers/microsoft/azure/provider.yaml index 10cab462b4c53..805fed9ab973c 100644 --- a/airflow/providers/microsoft/azure/provider.yaml +++ b/airflow/providers/microsoft/azure/provider.yaml @@ -88,6 +88,10 @@ integrations: external-doc-url: https://azure.microsoft.com/ logo: /integration-logos/azure/Microsoft-Azure.png tags: [azure] + - integration-name: Microsoft Azure Service Bus + external-doc-url: https://azure.microsoft.com/en-us/services/service-bus/ + logo: /integration-logos/azure/Service-Bus.svg + tags: [azure] operators: - integration-name: Microsoft Azure Data Lake Storage @@ -116,6 +120,9 @@ operators: - integration-name: Microsoft Azure Data Factory python-modules: - airflow.providers.microsoft.azure.operators.data_factory + - integration-name: Microsoft Azure Service Bus + python-modules: + - airflow.providers.microsoft.azure.operators.azure_service_bus_queue sensors: - integration-name: Microsoft Azure Cosmos DB @@ -167,6 +174,11 @@ hooks: python-modules: - airflow.providers.microsoft.azure.hooks.data_factory - airflow.providers.microsoft.azure.hooks.azure_data_factory + - integration-name: Microsoft Azure Service Bus + python-modules: + - airflow.providers.microsoft.azure.hooks.base_asb + - airflow.providers.microsoft.azure.hooks.asb_admin_client + - airflow.providers.microsoft.azure.hooks.asb_message transfers: - source-integration-name: Local @@ -203,6 +215,7 @@ hook-class-names: # deprecated - to be removed after providers add dependency o - airflow.providers.microsoft.azure.hooks.wasb.WasbHook - airflow.providers.microsoft.azure.hooks.data_factory.AzureDataFactoryHook - airflow.providers.microsoft.azure.hooks.container_registry.AzureContainerRegistryHook + - airflow.providers.microsoft.azure.hooks.base_asb.BaseAzureServiceBusHook connection-types: - hook-class-name: airflow.providers.microsoft.azure.hooks.base_azure.AzureBaseHook @@ -229,6 +242,8 @@ connection-types: - hook-class-name: >- airflow.providers.microsoft.azure.hooks.container_registry.AzureContainerRegistryHook connection-type: azure_container_registry + - hook-class-name: airflow.providers.microsoft.azure.hooks.base_asb.BaseAzureServiceBusHook + connection-type: azure_service_bus secrets-backends: - airflow.providers.microsoft.azure.secrets.key_vault.AzureKeyVaultBackend diff --git a/docs/apache-airflow-providers-microsoft-azure/connections/azure_service_bus.rst b/docs/apache-airflow-providers-microsoft-azure/connections/azure_service_bus.rst new file mode 100644 index 0000000000000..6b95b86a567c1 --- /dev/null +++ b/docs/apache-airflow-providers-microsoft-azure/connections/azure_service_bus.rst @@ -0,0 +1,64 @@ +.. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + + +.. _howto/connection:azure_service_bus: + +Microsoft Azure Service Bus +======================================= + +The Microsoft Azure Service Bus connection type enables the Azure Service Bus Integration. + +Authenticating to Azure Service Bus +------------------------------------ + +There are two ways to authenticate and authorize access to Azure Service Bus resources: +Azure Active Directory (Azure AD) and Shared Access Signatures (SAS). + +1. Use `Azure Active Directory + `_ + i.e. add specific credentials (client_id, secret, tenant) and subscription id to the Airflow connection. +2. Use `Azure Shared access signature + `_ +3. Fallback on `DefaultAzureCredential + `_. + This includes a mechanism to try different options to authenticate: Managed System Identity, environment variables, authentication through Azure CLI... + +Default Connection IDs +---------------------- + +All hooks and operators related to Microsoft Azure Service Bus use ``azure_service_bus_default`` by default. + +Configuring the Connection +-------------------------- + +Client ID (optional) + Specify the ``client_id`` used for the initial connection. + This is needed for *token credentials* authentication mechanism. + It can be left out to fall back on ``DefaultAzureCredential``. + +Secret (optional) + Specify the ``secret`` used for the initial connection. + This is needed for *token credentials* authentication mechanism. + It can be left out to fall back on ``DefaultAzureCredential``. + +Connection ID + Specify the Azure Service bus connection string ID used for the initial connection. + `Get connection string + `_ + Use the key ``extra__azure_service_bus__connection_string`` to pass in the Connection ID . diff --git a/docs/integration-logos/azure/Service-Bus.svg b/docs/integration-logos/azure/Service-Bus.svg new file mode 100644 index 0000000000000..1604e04232630 --- /dev/null +++ b/docs/integration-logos/azure/Service-Bus.svg @@ -0,0 +1 @@ + diff --git a/setup.py b/setup.py index 7ce5e61e1a785..922b2e0c54ad6 100644 --- a/setup.py +++ b/setup.py @@ -231,6 +231,7 @@ def write_version(filename: str = os.path.join(*[my_dir, "airflow", "git_version 'azure-storage-blob>=12.7.0,<12.9.0', 'azure-storage-common>=2.1.0', 'azure-storage-file>=2.1.0', + 'azure-servicebus>=7.6.1', ] cassandra = [ 'cassandra-driver>=3.13.0', diff --git a/tests/providers/microsoft/azure/hooks/test_asb_admin_client.py b/tests/providers/microsoft/azure/hooks/test_asb_admin_client.py new file mode 100644 index 0000000000000..4fb9f4a686001 --- /dev/null +++ b/tests/providers/microsoft/azure/hooks/test_asb_admin_client.py @@ -0,0 +1,96 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from unittest import mock + +import pytest +from azure.servicebus.management import ServiceBusAdministrationClient + +from airflow.models import Connection +from airflow.providers.microsoft.azure.hooks.asb_admin_client import AzureServiceBusAdminClientHook + + +class TestAzureServiceBusAdminClientHook: + def setup_class(self) -> None: + self.queue_name: str = "test_queue" + self.conn_id: str = 'azure_service_bus_default' + self.connection_string = ( + "Endpoint=sb://test-service-bus-provider.servicebus.windows.net/;" + "SharedAccessKeyName=Test;SharedAccessKey=1234566acbc" + ) + self.client_id = "test_client_id" + self.secret_key = "test_client_secret" + self.mock_conn = Connection( + conn_id='azure_service_bus_default', + conn_type='azure_service_bus', + login=self.client_id, + password=self.secret_key, + extra={"connection_string": self.connection_string}, + ) + + @mock.patch( + "airflow.providers.microsoft.azure.hooks.asb_admin_client." + "AzureServiceBusAdminClientHook.get_connection" + ) + def test_get_conn(self, mock_connection): + mock_connection.return_value = self.mock_conn + hook = AzureServiceBusAdminClientHook(azure_service_bus_conn_id=self.conn_id) + assert isinstance(hook.get_conn(), ServiceBusAdministrationClient) + + @mock.patch('azure.servicebus.management.QueueProperties') + @mock.patch( + 'airflow.providers.microsoft.azure.hooks.asb_admin_client.AzureServiceBusAdminClientHook.get_conn' + ) + def test_create_queue(self, mock_sb_admin_client, mock_queue_properties): + """ + Test `create_queue` hook function with mocking connection, queue properties value and + the azure service bus `create_queue` function + """ + mock_queue_properties.name = self.queue_name + mock_sb_admin_client.return_value.__enter__.return_value.create_queue.return_value = ( + mock_queue_properties + ) + hook = AzureServiceBusAdminClientHook(azure_service_bus_conn_id=self.conn_id) + response = hook.create_queue(self.queue_name) + assert response == mock_queue_properties + + @mock.patch('airflow.providers.microsoft.azure.hooks.asb_admin_client.ServiceBusAdministrationClient') + def test_create_queue_exception(self, mock_sb_admin_client): + """Test `create_queue` functionality to raise ValueError by passing queue name as None""" + hook = AzureServiceBusAdminClientHook(azure_service_bus_conn_id=self.conn_id) + with pytest.raises(ValueError): + hook.create_queue(None) + + @mock.patch( + 'airflow.providers.microsoft.azure.hooks.asb_admin_client.AzureServiceBusAdminClientHook.get_conn' + ) + def test_delete_queue(self, mock_sb_admin_client): + """ + Test Delete queue functionality by passing queue name, assert the function with values, + mock the azure service bus function `delete_queue` + """ + hook = AzureServiceBusAdminClientHook(azure_service_bus_conn_id=self.conn_id) + hook.delete_queue(self.queue_name) + expected_calls = [mock.call().__enter__().delete_queue(self.queue_name)] + mock_sb_admin_client.assert_has_calls(expected_calls) + + @mock.patch('airflow.providers.microsoft.azure.hooks.asb_admin_client.ServiceBusAdministrationClient') + def test_delete_queue_exception(self, mock_sb_admin_client): + """Test `delete_queue` functionality to raise ValueError, by passing queue name as None""" + hook = AzureServiceBusAdminClientHook(azure_service_bus_conn_id=self.conn_id) + with pytest.raises(ValueError): + hook.delete_queue(None) diff --git a/tests/providers/microsoft/azure/hooks/test_asb_message.py b/tests/providers/microsoft/azure/hooks/test_asb_message.py new file mode 100644 index 0000000000000..1755465d04062 --- /dev/null +++ b/tests/providers/microsoft/azure/hooks/test_asb_message.py @@ -0,0 +1,164 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import json +from unittest import mock + +import pytest +from azure.servicebus import ServiceBusClient, ServiceBusMessage, ServiceBusMessageBatch + +from airflow.models import Connection +from airflow.providers.microsoft.azure.hooks.asb_message import ServiceBusMessageHook + +MESSAGE = "Test Message" +MESSAGE_LIST = [MESSAGE + " " + str(n) for n in range(0, 10)] + + +class TestServiceBusMessageHook: + def setup_class(self) -> None: + self.queue_name: str = "test_queue" + self.conn_id: str = 'azure_service_bus_default' + self.connection_string = ( + "Endpoint=sb://test-service-bus-provider.servicebus.windows.net/;" + "SharedAccessKeyName=Test;SharedAccessKey=1234566acbc" + ) + self.client_id = "test_client_id" + self.secret_key = "test_client_secret" + self.conn = Connection( + conn_id='azure_service_bus_default', + conn_type='azure_service_bus', + login=self.client_id, + password=self.secret_key, + extra=json.dumps({'connection_string': self.connection_string}), + ) + + @mock.patch("airflow.providers.microsoft.azure.hooks.asb_message.ServiceBusMessageHook.get_connection") + def test_get_service_bus_message_conn(self, mock_connection): + """ + Test get_conn() function and check whether the get_conn() function returns value + is instance of ServiceBusClient + """ + mock_connection.return_value = self.conn + hook = ServiceBusMessageHook(azure_service_bus_conn_id=self.conn_id) + assert isinstance(hook.get_conn(), ServiceBusClient) + + @mock.patch("airflow.providers.microsoft.azure.hooks.asb_message.ServiceBusMessageHook.get_connection") + def test_get_conn_value_error(self, mock_connection): + """Test get_conn() function and check whether the get_conn() raise Value error""" + mock_connection.return_value = Connection( + conn_id='azure_service_bus_default', + conn_type='azure_service_bus', + login=self.client_id, + password=self.secret_key, + extra=json.dumps({"connection_string": "test connection"}), + ) + hook = ServiceBusMessageHook(azure_service_bus_conn_id=self.conn_id) + with pytest.raises(ValueError): + hook.get_conn() + + @pytest.mark.parametrize( + "mock_message, mock_batch_flag", + [ + (MESSAGE, True), + (MESSAGE, False), + (MESSAGE_LIST, True), + (MESSAGE_LIST, False), + ], + ) + @mock.patch( + 'airflow.providers.microsoft.azure.hooks.asb_message.ServiceBusMessageHook.send_list_messages' + ) + @mock.patch( + 'airflow.providers.microsoft.azure.hooks.asb_message.ServiceBusMessageHook.send_batch_message' + ) + @mock.patch('airflow.providers.microsoft.azure.hooks.asb_message.ServiceBusMessageHook.get_conn') + def test_send_message( + self, mock_sb_client, mock_batch_message, mock_list_message, mock_message, mock_batch_flag + ): + """ + Test `send_message` hook function with batch flag and message passed as mocked params, + which can be string or list of string, mock the azure service bus `send_messages` function + """ + hook = ServiceBusMessageHook(azure_service_bus_conn_id="azure_service_bus_default") + hook.send_message( + queue_name=self.queue_name, messages=mock_message, batch_message_flag=mock_batch_flag + ) + if isinstance(mock_message, list): + if mock_batch_flag: + message = ServiceBusMessageBatch(mock_message) + else: + message = [ServiceBusMessage(msg) for msg in mock_message] + elif isinstance(mock_message, str): + if mock_batch_flag: + message = ServiceBusMessageBatch(mock_message) + else: + message = ServiceBusMessage(mock_message) + + expected_calls = [ + mock.call() + .__enter__() + .get_queue_sender(self.queue_name) + .__enter__() + .send_messages(message) + .__exit__() + ] + mock_sb_client.assert_has_calls(expected_calls, any_order=False) + + @mock.patch('airflow.providers.microsoft.azure.hooks.asb_message.ServiceBusMessageHook.get_conn') + def test_send_message_exception(self, mock_sb_client): + """ + Test `send_message` functionality to raise AirflowException in Azure ServiceBusMessageHook + by passing queue name as None + """ + hook = ServiceBusMessageHook(azure_service_bus_conn_id=self.conn_id) + with pytest.raises(TypeError): + hook.send_message(queue_name=None, messages="", batch_message_flag=False) + + @mock.patch('azure.servicebus.ServiceBusMessage') + @mock.patch('airflow.providers.microsoft.azure.hooks.asb_message.ServiceBusMessageHook.get_conn') + def test_receive_message(self, mock_sb_client, mock_service_bus_message): + """ + Test `receive_message` hook function and assert the function with mock value, + mock the azure service bus `receive_messages` function + """ + hook = ServiceBusMessageHook(azure_service_bus_conn_id=self.conn_id) + mock_sb_client.return_value.get_queue_receiver.return_value.receive_messages.return_value = [ + mock_service_bus_message + ] + hook.receive_message(self.queue_name) + expected_calls = [ + mock.call() + .__enter__() + .get_queue_receiver(self.queue_name) + .__enter__() + .receive_messages(max_message_count=30, max_wait_time=5) + .get_queue_receiver(self.queue_name) + .__exit__() + .mock_call() + .__exit__ + ] + mock_sb_client.assert_has_calls(expected_calls) + + @mock.patch('airflow.providers.microsoft.azure.hooks.asb_message.ServiceBusMessageHook.get_conn') + def test_receive_message_exception(self, mock_sb_client): + """ + Test `receive_message` functionality to raise AirflowException in Azure ServiceBusMessageHook + by passing queue name as None + """ + hook = ServiceBusMessageHook(azure_service_bus_conn_id=self.conn_id) + with pytest.raises(ValueError): + hook.receive_message(None) diff --git a/tests/providers/microsoft/azure/operators/test_azure_service_bus_queue.py b/tests/providers/microsoft/azure/operators/test_azure_service_bus_queue.py new file mode 100644 index 0000000000000..7f837fe0f719e --- /dev/null +++ b/tests/providers/microsoft/azure/operators/test_azure_service_bus_queue.py @@ -0,0 +1,199 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from unittest import mock + +import pytest +from azure.servicebus import ServiceBusMessage + +from airflow.providers.microsoft.azure.operators.azure_service_bus_queue import ( + AzureServiceBusCreateQueueOperator, + AzureServiceBusDeleteQueueOperator, + AzureServiceBusReceiveMessageOperator, + AzureServiceBusSendMessageOperator, +) + +QUEUE_NAME = "test_queue" +MESSAGE = "Test Message" +MESSAGE_LIST = [MESSAGE + " " + str(n) for n in range(0, 10)] + + +class TestAzureServiceBusCreateQueueOperator: + @pytest.mark.parametrize( + "mock_dl_msg_expiration, mock_batched_operation", + [ + (True, True), + (True, False), + (False, True), + (False, False), + ], + ) + def test_init(self, mock_dl_msg_expiration, mock_batched_operation): + """ + Test init by creating AzureServiceBusCreateQueueOperator with task id, + queue_name and asserting with value + """ + asb_create_queue_operator = AzureServiceBusCreateQueueOperator( + task_id="asb_create_queue", + queue_name=QUEUE_NAME, + max_delivery_count=10, + dead_lettering_on_message_expiration=mock_dl_msg_expiration, + enable_batched_operations=mock_batched_operation, + ) + assert asb_create_queue_operator.task_id == "asb_create_queue" + assert asb_create_queue_operator.queue_name == QUEUE_NAME + assert asb_create_queue_operator.max_delivery_count == 10 + assert asb_create_queue_operator.dead_lettering_on_message_expiration is mock_dl_msg_expiration + assert asb_create_queue_operator.enable_batched_operations is mock_batched_operation + + @mock.patch( + "airflow.providers.microsoft.azure.hooks.asb_admin_client.AzureServiceBusAdminClientHook.get_conn" + ) + def test_create_queue(self, mock_get_conn): + """ + Test AzureServiceBusCreateQueueOperator passed with the queue name, + mocking the connection details, hook create_queue function + """ + asb_create_queue_operator = AzureServiceBusCreateQueueOperator( + task_id="asb_create_queue_operator", + queue_name=QUEUE_NAME, + max_delivery_count=10, + dead_lettering_on_message_expiration=True, + enable_batched_operations=True, + ) + asb_create_queue_operator.execute(None) + mock_get_conn.return_value.__enter__.return_value.create_queue.assert_called_once_with( + QUEUE_NAME, + max_delivery_count=10, + dead_lettering_on_message_expiration=True, + enable_batched_operations=True, + ) + + +class TestAzureServiceBusDeleteQueueOperator: + def test_init(self): + """ + Test init by creating AzureServiceBusDeleteQueueOperator with task id, queue_name and asserting + with values + """ + asb_delete_queue_operator = AzureServiceBusDeleteQueueOperator( + task_id="asb_delete_queue", + queue_name=QUEUE_NAME, + ) + assert asb_delete_queue_operator.task_id == "asb_delete_queue" + assert asb_delete_queue_operator.queue_name == QUEUE_NAME + + @mock.patch( + "airflow.providers.microsoft.azure.hooks.asb_admin_client.AzureServiceBusAdminClientHook.get_conn" + ) + def test_delete_queue(self, mock_get_conn): + """Test AzureServiceBusDeleteQueueOperator by mocking queue name, connection and hook delete_queue""" + asb_delete_queue_operator = AzureServiceBusDeleteQueueOperator( + task_id="asb_delete_queue", + queue_name=QUEUE_NAME, + ) + asb_delete_queue_operator.execute(None) + mock_get_conn.return_value.__enter__.return_value.delete_queue.assert_called_once_with(QUEUE_NAME) + + +class TestAzureServiceBusSendMessageOperator: + @pytest.mark.parametrize( + "mock_message, mock_batch_flag", + [ + (MESSAGE, True), + (MESSAGE, False), + (MESSAGE_LIST, True), + (MESSAGE_LIST, False), + ], + ) + def test_init(self, mock_message, mock_batch_flag): + """ + Test init by creating AzureServiceBusSendMessageOperator with task id, queue_name, message, + batch and asserting with values + """ + asb_send_message_queue_operator = AzureServiceBusSendMessageOperator( + task_id="asb_send_message_queue_without_batch", + queue_name=QUEUE_NAME, + message=mock_message, + batch=mock_batch_flag, + ) + assert asb_send_message_queue_operator.task_id == "asb_send_message_queue_without_batch" + assert asb_send_message_queue_operator.queue_name == QUEUE_NAME + assert asb_send_message_queue_operator.message == mock_message + assert asb_send_message_queue_operator.batch is mock_batch_flag + + @mock.patch("airflow.providers.microsoft.azure.hooks.asb_message.ServiceBusMessageHook.get_conn") + def test_send_message_queue(self, mock_get_conn): + """ + Test AzureServiceBusSendMessageOperator with queue name, batch boolean flag, mock + the send_messages of azure service bus function + """ + asb_send_message_queue_operator = AzureServiceBusSendMessageOperator( + task_id="asb_send_message_queue", + queue_name=QUEUE_NAME, + message="Test message", + batch=False, + ) + asb_send_message_queue_operator.execute(None) + expected_calls = [ + mock.call() + .__enter__() + .get_queue_sender(QUEUE_NAME) + .__enter__() + .send_messages(ServiceBusMessage("Test message")) + .__exit__() + ] + mock_get_conn.assert_has_calls(expected_calls, any_order=False) + + +class TestAzureServiceBusReceiveMessageOperator: + def test_init(self): + """ + Test init by creating AzureServiceBusReceiveMessageOperator with task id, queue_name, message, + batch and asserting with values + """ + + asb_receive_queue_operator = AzureServiceBusReceiveMessageOperator( + task_id="asb_receive_message_queue", + queue_name=QUEUE_NAME, + ) + assert asb_receive_queue_operator.task_id == "asb_receive_message_queue" + assert asb_receive_queue_operator.queue_name == QUEUE_NAME + + @mock.patch("airflow.providers.microsoft.azure.hooks.asb_message.ServiceBusMessageHook.get_conn") + def test_receive_message_queue(self, mock_get_conn): + """ + Test AzureServiceBusReceiveMessageOperator by mock connection, values + and the service bus receive message + """ + asb_receive_queue_operator = AzureServiceBusReceiveMessageOperator( + task_id="asb_receive_message_queue", + queue_name=QUEUE_NAME, + ) + asb_receive_queue_operator.execute(None) + expected_calls = [ + mock.call() + .__enter__() + .get_queue_receiver(QUEUE_NAME) + .__enter__() + .receive_messages(max_message_count=10, max_wait_time=5) + .get_queue_receiver(QUEUE_NAME) + .__exit__() + .mock_call() + .__exit__ + ] + mock_get_conn.assert_has_calls(expected_calls) From f8a33f140c3f097ca2eb9c04c439bf105f71be22 Mon Sep 17 00:00:00 2001 From: bharanidharan14 Date: Fri, 3 Jun 2022 09:32:54 +0530 Subject: [PATCH 13/13] Fix - Fixed doc string, and test case --- airflow/providers/microsoft/azure/hooks/asb_admin_client.py | 5 ++--- tests/providers/microsoft/azure/hooks/test_asb_message.py | 2 +- .../azure/operators/test_azure_service_subscription.py | 2 +- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/airflow/providers/microsoft/azure/hooks/asb_admin_client.py b/airflow/providers/microsoft/azure/hooks/asb_admin_client.py index afd1eb8e45d98..d2a6753d4b84e 100644 --- a/airflow/providers/microsoft/azure/hooks/asb_admin_client.py +++ b/airflow/providers/microsoft/azure/hooks/asb_admin_client.py @@ -145,9 +145,8 @@ def create_subscription( :param forward_dead_lettered_messages_to: The name of the recipient entity to which all the messages sent to the subscription are forwarded to. :param auto_delete_on_idle: ISO 8601 timeSpan idle interval after which the subscription is - automatically deleted. The minimum duration is 5 minutes. - Input value of either type ~datetime.timedelta or string in ISO 8601 duration format like - "PT300S" is accepted. + automatically deleted. The minimum duration is 5 minutes. Input value of either + type ~datetime.timedelta or string in ISO 8601 duration format like "PT300S" is accepted. """ if subscription_name is None: raise AirflowBadRequest("Subscription name cannot be None.") diff --git a/tests/providers/microsoft/azure/hooks/test_asb_message.py b/tests/providers/microsoft/azure/hooks/test_asb_message.py index c3e7202cf76b1..51d3cebeb47d2 100644 --- a/tests/providers/microsoft/azure/hooks/test_asb_message.py +++ b/tests/providers/microsoft/azure/hooks/test_asb_message.py @@ -174,7 +174,7 @@ def test_receive_subscription_message(self, mock_sb_client): .get_subscription_receiver(subscription_name, topic_name) .__enter__() .receive_messages(max_message_count=max_message_count, max_wait_time=max_wait_time) - .get_queue_receiver(self.queue_name) + .get_subscription_receiver(subscription_name, topic_name) .__exit__() .mock_call() .__exit__ diff --git a/tests/providers/microsoft/azure/operators/test_azure_service_subscription.py b/tests/providers/microsoft/azure/operators/test_azure_service_subscription.py index acfb9999a7913..808f5382f792a 100644 --- a/tests/providers/microsoft/azure/operators/test_azure_service_subscription.py +++ b/tests/providers/microsoft/azure/operators/test_azure_service_subscription.py @@ -215,7 +215,7 @@ def test_receive_message_queue(self, mock_get_conn): .get_subscription_receiver(SUBSCRIPTION_NAME, TOPIC_NAME) .__enter__() .receive_messages(max_message_count=10, max_wait_time=5) - .get_queue_receiver(self.queue_name) + .get_subscription_receiver(SUBSCRIPTION_NAME, TOPIC_NAME) .__exit__() .mock_call() .__exit__