From e2d9b3b4a831212c646c646e9012daa6e8ceb606 Mon Sep 17 00:00:00 2001 From: wjddn279 Date: Thu, 7 May 2026 18:09:45 +0900 Subject: [PATCH] Migrate Azure Batch provider to azure-batch 15.x (track 2 SDK) --- providers/microsoft/azure/docs/changelog.rst | 45 ++++ providers/microsoft/azure/docs/index.rst | 2 +- providers/microsoft/azure/pyproject.toml | 4 +- .../providers/microsoft/azure/hooks/batch.py | 180 ++++++------- .../microsoft/azure/operators/batch.py | 54 ++-- .../unit/microsoft/azure/hooks/test_batch.py | 254 ++++++++---------- .../microsoft/azure/operators/test_batch.py | 143 +++++----- scripts/ci/prek/known_airflow_exceptions.txt | 3 +- 8 files changed, 325 insertions(+), 360 deletions(-) diff --git a/providers/microsoft/azure/docs/changelog.rst b/providers/microsoft/azure/docs/changelog.rst index 011dafe00be96..fe2c8f379c210 100644 --- a/providers/microsoft/azure/docs/changelog.rst +++ b/providers/microsoft/azure/docs/changelog.rst @@ -27,6 +27,51 @@ Changelog --------- +14.0.0 +...... + +Breaking changes +~~~~~~~~~~~~~~~~ + +.. warning:: + The Azure Batch hook and operator have been migrated to ``azure-batch>=15.0.0`` + (Azure SDK "track 2"). Users must update their code if they depend on Azure Batch + model classes or extend ``AzureBatchHook`` / ``AzureBatchOperator``. See the + `azure-batch CHANGELOG <https://github.com/Azure/azure-sdk-for-python/blob/main/sdk/batch/azure-batch/CHANGELOG.md>`__ + for the full list of upstream renames and removals. + +* The ``azure-batch`` dependency floor has been raised to ``>=15.0.0``. The 15.x line + is a full rewrite that removes ``BatchServiceClient`` (replaced by + ``BatchClient``), ``batch_auth.SharedKeyCredentials`` (replaced by + ``azure.core.credentials.AzureNamedKeyCredential``), and many model classes have + been renamed. 
+* ``AzureBatchOperator`` constructor type hints have been updated to the new model names: + ``JobManagerTask`` -> ``BatchJobManagerTask``, + ``JobPreparationTask`` -> ``BatchJobPreparationTask``, + ``JobReleaseTask`` -> ``BatchJobReleaseTask``, + ``TaskContainerSettings`` -> ``BatchTaskContainerSettings``, + ``StartTask`` -> ``BatchStartTask``. Update Dag code that constructs and passes + these models accordingly. +* The ``os_family`` and ``os_version`` parameters have been removed from + ``AzureBatchHook.configure_pool`` and ``AzureBatchOperator``. Cloud Service pool + configuration is no longer supported by Azure Batch (azure-batch 15.x removed + ``CloudServiceConfiguration``); passing either name as a keyword argument now + raises ``ValueError``. Use ``vm_publisher`` / ``vm_offer`` / ``vm_sku`` / + ``vm_node_agent_sku_id`` instead. +* ``AzureBatchOperator.batch_max_retries`` is currently a no-op. The previous + implementation assigned an integer to ``BatchServiceClient.config.retry_policy``, + which expected a policy object; the new client manages retries via its own + pipeline. The parameter is retained to avoid signature breakage and may be wired + to the new client's retry kwargs in a future release. +* Pool / job deletion and job termination now go through the new + long-running operation surface (``begin_delete_pool``, ``begin_delete_job``, + ``begin_terminate_job``). This is transparent to Dag authors using the operator + but affects subclasses that called ``connection.pool.delete`` / ``job.delete`` / + ``job.terminate`` directly. +* ``AzureBatchHook`` no longer routes Azure Identity credentials through + ``AzureIdentityCredentialAdapter``. The new ``BatchClient`` accepts + ``azure-identity`` ``TokenCredential`` instances natively. + 13.2.0 ...... 
diff --git a/providers/microsoft/azure/docs/index.rst b/providers/microsoft/azure/docs/index.rst index c86007b64f9cd..6ec167e0f436e 100644 --- a/providers/microsoft/azure/docs/index.rst +++ b/providers/microsoft/azure/docs/index.rst @@ -110,7 +110,7 @@ PIP package Version required ``apache-airflow`` ``>=2.11.0`` ``apache-airflow-providers-common-compat`` ``>=1.13.0`` ``adlfs`` ``>=2023.10.0`` -``azure-batch`` ``<15.0.0,>=8.0.0`` +``azure-batch`` ``>=15.0.0`` ``azure-cosmos`` ``>=4.6.0`` ``azure-mgmt-cosmosdb`` ``>=3.0.0`` ``azure-datalake-store`` ``>=0.0.45`` diff --git a/providers/microsoft/azure/pyproject.toml b/providers/microsoft/azure/pyproject.toml index e7df3848d6fca..3977480395e9d 100644 --- a/providers/microsoft/azure/pyproject.toml +++ b/providers/microsoft/azure/pyproject.toml @@ -62,9 +62,7 @@ dependencies = [ "apache-airflow>=2.11.0", "apache-airflow-providers-common-compat>=1.13.0", "adlfs>=2023.10.0", - # azure-batch 15.x is a full rewrite of the Azure SDK (track 2) that removes BatchServiceClient, batch_auth, - # and the other references in AzureBatchHook. Lifting the upper bound cap needs a full hook rewrite. 
- "azure-batch>=8.0.0,<15.0.0", + "azure-batch>=15.0.0", "azure-cosmos>=4.6.0", "azure-mgmt-cosmosdb>=3.0.0", "azure-datalake-store>=0.0.45", diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/batch.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/batch.py index 06f0cf8710fa2..fcf136d2c6290 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/batch.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/batch.py @@ -22,18 +22,24 @@ from functools import cached_property from typing import TYPE_CHECKING, Any -from azure.batch import BatchServiceClient, batch_auth, models as batch_models +from azure.batch import BatchClient, models as batch_models +from azure.core.credentials import AzureNamedKeyCredential, TokenCredential +from azure.core.exceptions import HttpResponseError, ResourceExistsError -from airflow.providers.common.compat.sdk import AirflowException, BaseHook +from airflow.providers.common.compat.sdk import BaseHook from airflow.providers.microsoft.azure.utils import ( - AzureIdentityCredentialAdapter, add_managed_identity_connection_widgets, get_field, + get_sync_default_azure_credential, ) from airflow.utils import timezone if TYPE_CHECKING: - from azure.batch.models import JobAddParameter, PoolAddParameter, TaskAddParameter + from azure.batch.models import ( + BatchJobCreateOptions, + BatchPoolCreateOptions, + BatchTaskCreateOptions, + ) class AzureBatchHook(BaseHook): @@ -85,11 +91,11 @@ def _get_field(self, extras, name): ) @cached_property - def connection(self) -> BatchServiceClient: + def connection(self) -> BatchClient: """Get the Batch client connection (cached).""" return self.get_conn() - def get_conn(self) -> BatchServiceClient: + def get_conn(self) -> BatchClient: """ Get the Batch client connection. 
@@ -99,23 +105,18 @@ def get_conn(self) -> BatchServiceClient: batch_account_url = self._get_field(conn.extra_dejson, "account_url") if not batch_account_url: - raise AirflowException("Batch Account URL parameter is missing.") + raise ValueError("Batch Account URL parameter is missing.") - credentials: batch_auth.SharedKeyCredentials | AzureIdentityCredentialAdapter + credential: AzureNamedKeyCredential | TokenCredential if all([conn.login, conn.password]): - credentials = batch_auth.SharedKeyCredentials(conn.login, conn.password) + credential = AzureNamedKeyCredential(conn.login, conn.password) else: - managed_identity_client_id = conn.extra_dejson.get("managed_identity_client_id") - workload_identity_tenant_id = conn.extra_dejson.get("workload_identity_tenant_id") - credentials = AzureIdentityCredentialAdapter( - None, - resource_id="https://batch.core.windows.net/.default", - managed_identity_client_id=managed_identity_client_id, - workload_identity_tenant_id=workload_identity_tenant_id, + credential = get_sync_default_azure_credential( + managed_identity_client_id=conn.extra_dejson.get("managed_identity_client_id"), + workload_identity_tenant_id=conn.extra_dejson.get("workload_identity_tenant_id"), ) - batch_client = BatchServiceClient(credentials, batch_url=batch_account_url) - return batch_client + return BatchClient(endpoint=batch_account_url, credential=credential) def configure_pool( self, @@ -127,52 +128,40 @@ def configure_pool( sku_starts_with: str | None = None, vm_sku: str | None = None, vm_version: str | None = None, - os_family: str | None = None, - os_version: str | None = None, display_name: str | None = None, target_dedicated_nodes: int | None = None, use_latest_image_and_sku: bool = False, **kwargs, - ) -> PoolAddParameter: + ) -> BatchPoolCreateOptions: """ Configure a pool. :param pool_id: A string that uniquely identifies the Pool within the Account - :param vm_size: The size of virtual machines in the Pool. 
- :param display_name: The display name for the Pool - :param target_dedicated_nodes: The desired number of dedicated Compute Nodes in the Pool. - :param use_latest_image_and_sku: Whether to use the latest verified vm image and sku - :param vm_publisher: The publisher of the Azure Virtual Machines Marketplace Image. For example, Canonical or MicrosoftWindowsServer. - :param vm_offer: The offer type of the Azure Virtual Machines Marketplace Image. For example, UbuntuServer or WindowsServer. - :param sku_starts_with: The start name of the sku to search - :param vm_sku: The name of the virtual machine sku to use - :param vm_version: The version of the virtual machine - :param vm_version: str - :param vm_node_agent_sku_id: The node agent sku id of the virtual machine - - :param os_family: The Azure Guest OS family to be installed on the virtual machines in the Pool. - - :param os_version: The OS family version - """ + if "os_family" in kwargs or "os_version" in kwargs: + raise ValueError( + "Cloud Service pools (os_family/os_version) are no longer supported by Azure Batch. " + "Use vm_publisher/vm_offer/vm_sku/vm_node_agent_sku_id instead." 
+ ) + if use_latest_image_and_sku: self.log.info("Using latest verified virtual machine image with node agent sku") sku_to_use, image_ref_to_use = self._get_latest_verified_image_vm_and_sku( publisher=vm_publisher, offer=vm_offer, sku_starts_with=sku_starts_with ) - pool = batch_models.PoolAddParameter( + return batch_models.BatchPoolCreateOptions( id=pool_id, vm_size=vm_size, display_name=display_name, @@ -183,57 +172,45 @@ def configure_pool( **kwargs, ) - elif os_family: - self.log.info( - "Using cloud service configuration to create pool, virtual machine configuration ignored" - ) - pool = batch_models.PoolAddParameter( - id=pool_id, - vm_size=vm_size, - display_name=display_name, - cloud_service_configuration=batch_models.CloudServiceConfiguration( - os_family=os_family, os_version=os_version - ), - target_dedicated_nodes=target_dedicated_nodes, - **kwargs, - ) - - else: - self.log.info("Using virtual machine configuration to create a pool") - pool = batch_models.PoolAddParameter( - id=pool_id, - vm_size=vm_size, - display_name=display_name, - virtual_machine_configuration=batch_models.VirtualMachineConfiguration( - image_reference=batch_models.ImageReference( - publisher=vm_publisher, - offer=vm_offer, - sku=vm_sku, - version=vm_version, - ), - node_agent_sku_id=vm_node_agent_sku_id, + self.log.info("Using virtual machine configuration to create a pool") + return batch_models.BatchPoolCreateOptions( + id=pool_id, + vm_size=vm_size, + display_name=display_name, + virtual_machine_configuration=batch_models.VirtualMachineConfiguration( + image_reference=batch_models.BatchVmImageReference( + publisher=vm_publisher, + offer=vm_offer, + sku=vm_sku, + version=vm_version, ), - target_dedicated_nodes=target_dedicated_nodes, - **kwargs, - ) - return pool + node_agent_sku_id=vm_node_agent_sku_id, + ), + target_dedicated_nodes=target_dedicated_nodes, + **kwargs, + ) - def create_pool(self, pool: PoolAddParameter) -> None: + def create_pool(self, pool: 
BatchPoolCreateOptions) -> None: """ Create a pool if not already existing. :param pool: the pool object to create - """ try: self.log.info("Attempting to create a pool: %s", pool.id) - self.connection.pool.add(pool) + self.connection.create_pool(pool) self.log.info("Created pool: %s", pool.id) - except batch_models.BatchErrorException as err: - if not err.error or err.error.code != "PoolExists": + except (ResourceExistsError, HttpResponseError) as err: + if not self._is_error_code(err, "PoolExists"): raise self.log.info("Pool %s already exists", pool.id) + @staticmethod + def _is_error_code(err: HttpResponseError, code: str) -> bool: + """Return True if the Batch error response carries the given code.""" + error = getattr(err, "model", None) + return getattr(error, "code", None) == code + def _get_latest_verified_image_vm_and_sku( self, publisher: str | None = None, @@ -249,8 +226,7 @@ def _get_latest_verified_image_vm_and_sku( For example, UbuntuServer or WindowsServer. :param sku_starts_with: The start name of the sku to search """ - options = batch_models.AccountListSupportedImagesOptions(filter="verificationType eq 'verified'") - images = self.connection.account.list_supported_images(account_list_supported_images_options=options) + images = self.connection.list_supported_images(filter="verificationType eq 'verified'") # pick the latest supported sku skus_to_use = [ (image.node_agent_sku_id, image.image_reference) @@ -269,16 +245,16 @@ def wait_for_all_node_state(self, pool_id: str, node_state: set) -> list: Wait for all nodes in a pool to reach given states. 
:param pool_id: A string that identifies the pool - :param node_state: A set of batch_models.ComputeNodeState + :param node_state: A set of batch_models.BatchNodeState """ self.log.info("waiting for all nodes in pool %s to reach one of: %s", pool_id, node_state) while True: # refresh pool to ensure that there is no resize error - pool = self.connection.pool.get(pool_id) + pool = self.connection.get_pool(pool_id) if pool.resize_errors is not None: resize_errors = "\n".join(repr(e) for e in pool.resize_errors) raise RuntimeError(f"resize error encountered for pool {pool.id}:\n{resize_errors}") - nodes = list(self.connection.compute_node.list(pool.id)) + nodes = list(self.connection.list_nodes(pool.id)) if len(nodes) >= pool.target_dedicated_nodes and all(node.state in node_state for node in nodes): return nodes # Allow the timeout to be controlled by the AzureBatchOperator @@ -292,7 +268,7 @@ def configure_job( pool_id: str, display_name: str | None = None, **kwargs, - ) -> JobAddParameter: + ) -> BatchJobCreateOptions: """ Configure a job for use in the pool. @@ -300,25 +276,24 @@ def configure_job( :param pool_id: A string that identifies the pool :param display_name: The display name for the job """ - job = batch_models.JobAddParameter( + return batch_models.BatchJobCreateOptions( id=job_id, - pool_info=batch_models.PoolInformation(pool_id=pool_id), + pool_info=batch_models.BatchPoolInfo(pool_id=pool_id), display_name=display_name, **kwargs, ) - return job - def create_job(self, job: JobAddParameter) -> None: + def create_job(self, job: BatchJobCreateOptions) -> None: """ Create a job in the pool. 
:param job: The job object to create """ try: - self.connection.job.add(job) + self.connection.create_job(job) self.log.info("Job %s created", job.id) - except batch_models.BatchErrorException as err: - if not err.error or err.error.code != "JobExists": + except (ResourceExistsError, HttpResponseError) as err: + if not self._is_error_code(err, "JobExists"): raise self.log.info("Job %s already exists", job.id) @@ -329,7 +304,7 @@ def configure_task( display_name: str | None = None, container_settings=None, **kwargs, - ) -> TaskAddParameter: + ) -> BatchTaskCreateOptions: """ Create a task. @@ -341,7 +316,7 @@ def configure_task( this must be set as well. If the Pool that will run this Task doesn't have containerConfiguration set, this must not be set. """ - task = batch_models.TaskAddParameter( + task = batch_models.BatchTaskCreateOptions( id=task_id, command_line=command_line, display_name=display_name, @@ -351,7 +326,7 @@ def configure_task( self.log.info("Task created: %s", task_id) return task - def add_single_task_to_job(self, job_id: str, task: TaskAddParameter) -> None: + def add_single_task_to_job(self, job_id: str, task: BatchTaskCreateOptions) -> None: """ Add a single task to given job if it doesn't exist. @@ -359,13 +334,13 @@ def add_single_task_to_job(self, job_id: str, task: TaskAddParameter) -> None: :param task: The task to add """ try: - self.connection.task.add(job_id=job_id, task=task) - except batch_models.BatchErrorException as err: - if not err.error or err.error.code != "TaskExists": + self.connection.create_task(job_id, task) + except (ResourceExistsError, HttpResponseError) as err: + if not self._is_error_code(err, "TaskExists"): raise self.log.info("Task %s already exists", task.id) - def wait_for_job_tasks_to_complete(self, job_id: str, timeout: int) -> list[batch_models.CloudTask]: + def wait_for_job_tasks_to_complete(self, job_id: str, timeout: int) -> list[batch_models.BatchTask]: """ Wait for tasks in a particular job to complete. 
@@ -374,15 +349,16 @@ def wait_for_job_tasks_to_complete(self, job_id: str, timeout: int) -> list[batc """ timeout_time = timezone.utcnow() + timedelta(minutes=timeout) while timezone.utcnow() < timeout_time: - tasks = list(self.connection.task.list(job_id)) + tasks = list(self.connection.list_tasks(job_id)) - incomplete_tasks = [task for task in tasks if task.state != batch_models.TaskState.completed] + incomplete_tasks = [task for task in tasks if task.state != batch_models.BatchTaskState.COMPLETED] if not incomplete_tasks: # detect if any task in job has failed fail_tasks = [ task for task in tasks - if task.execution_info.result == batch_models.TaskExecutionResult.failure + if task.execution_info + and task.execution_info.result == batch_models.BatchTaskExecutionResult.FAILURE ] return fail_tasks for task in incomplete_tasks: @@ -393,12 +369,12 @@ def wait_for_job_tasks_to_complete(self, job_id: str, timeout: int) -> list[batc def test_connection(self): """Test a configured Azure Batch connection.""" try: - # Attempt to list existing jobs under the configured Batch account and retrieve + # Attempt to list existing jobs under the configured Batch account and retrieve # the first in the returned iterator. The Azure Batch API does allow for creation of a - # BatchServiceClient with incorrect values but then will fail properly once items are + # BatchClient with incorrect values but then will fail properly once items are # retrieved using the client. We need to _actually_ try to retrieve an object to properly # test the connection. - next(self.get_conn().job.list(), None) + next(iter(self.get_conn().list_jobs()), None) except Exception as e: return False, str(e) return True, "Successfully connected to Azure Batch." 
diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/batch.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/batch.py index e5b36841c66cf..34d76652883f7 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/batch.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/batch.py @@ -86,8 +86,6 @@ class AzureBatchOperator(BaseOperator): :param vm_version: The version of the virtual machine :param vm_version: str | None :param vm_node_agent_sku_id: The node agent sku id of the virtual machine - :param os_family: The Azure Guest OS family to be installed on the virtual machines in the Pool. - :param os_version: The OS family version :param timeout: The amount of time to wait for the job to complete in minutes. Default is 25 :param should_delete_job: Whether to delete job after execution. Default is False :param should_delete_pool: Whether to delete pool after execution of jobs. 
Default is False @@ -116,16 +114,14 @@ def __init__( sku_starts_with: str | None = None, vm_sku: str | None = None, vm_version: str | None = None, - os_family: str | None = None, - os_version: str | None = None, batch_pool_display_name: str | None = None, batch_job_display_name: str | None = None, - batch_job_manager_task: batch_models.JobManagerTask | None = None, - batch_job_preparation_task: batch_models.JobPreparationTask | None = None, - batch_job_release_task: batch_models.JobReleaseTask | None = None, + batch_job_manager_task: batch_models.BatchJobManagerTask | None = None, + batch_job_preparation_task: batch_models.BatchJobPreparationTask | None = None, + batch_job_release_task: batch_models.BatchJobReleaseTask | None = None, batch_task_display_name: str | None = None, - batch_task_container_settings: batch_models.TaskContainerSettings | None = None, - batch_start_task: batch_models.StartTask | None = None, + batch_task_container_settings: batch_models.BatchTaskContainerSettings | None = None, + batch_start_task: batch_models.BatchStartTask | None = None, batch_max_retries: int = 3, batch_task_resource_files: list[batch_models.ResourceFile] | None = None, batch_task_output_files: list[batch_models.OutputFile] | None = None, @@ -141,6 +137,11 @@ def __init__( should_delete_pool: bool = False, **kwargs, ) -> None: + if "os_family" in kwargs or "os_version" in kwargs: + raise ValueError( + "Cloud Service pools (os_family/os_version) are no longer supported by Azure Batch. " + "Use vm_publisher/vm_offer/vm_sku/vm_node_agent_sku_id instead." 
+ ) super().__init__(**kwargs) self.batch_pool_id = batch_pool_id self.batch_pool_vm_size = batch_pool_vm_size @@ -171,8 +172,6 @@ def __init__( self.vm_sku = vm_sku self.vm_version = vm_version self.vm_node_agent_sku_id = vm_node_agent_sku_id - self.os_family = os_family - self.os_version = os_version self.timeout = timeout self.should_delete_job = should_delete_job self.should_delete_pool = should_delete_pool @@ -183,14 +182,8 @@ def hook(self) -> AzureBatchHook: return AzureBatchHook(self.azure_batch_conn_id) def _check_inputs(self) -> Any: - if not self.os_family and not self.vm_publisher: - raise AirflowException("You must specify either vm_publisher or os_family") - if self.os_family and self.vm_publisher: - raise AirflowException( - "Cloud service configuration and virtual machine configuration " - "are mutually exclusive. You must specify either of os_family and" - " vm_publisher" - ) + if not self.vm_publisher: + raise ValueError("vm_publisher is required") if self.use_latest_image: if not self.vm_publisher or not self.vm_offer: @@ -243,7 +236,6 @@ def _check_inputs(self) -> Any: def execute(self, context: Context) -> None: self._check_inputs() - self.hook.connection.config.retry_policy = self.batch_max_retries pool = self.hook.configure_pool( pool_id=self.batch_pool_id, @@ -257,8 +249,6 @@ def execute(self, context: Context) -> None: vm_sku=self.vm_sku, vm_version=self.vm_version, vm_node_agent_sku_id=self.vm_node_agent_sku_id, - os_family=self.os_family, - os_version=self.os_version, target_low_priority_nodes=self.target_low_priority_nodes, enable_auto_scale=self.enable_auto_scale, auto_scale_formula=self.auto_scale_formula, @@ -269,9 +259,9 @@ def execute(self, context: Context) -> None: self.hook.wait_for_all_node_state( self.batch_pool_id, { - batch_models.ComputeNodeState.start_task_failed, - batch_models.ComputeNodeState.unusable, - batch_models.ComputeNodeState.idle, + batch_models.BatchNodeState.START_TASK_FAILED, + 
batch_models.BatchNodeState.UNUSABLE, + batch_models.BatchNodeState.IDLE, }, ) # Create job if not already exist @@ -309,10 +299,11 @@ def execute(self, context: Context) -> None: raise AirflowException(f"Job fail. The failed task are: {fail_tasks}") def on_kill(self) -> None: - response = self.hook.connection.job.terminate( - job_id=self.batch_job_id, terminate_reason="Job killed by user" - ) - self.log.info("Azure Batch job (%s) terminated: %s", self.batch_job_id, response) + self.hook.connection.begin_terminate_job( + self.batch_job_id, + options=batch_models.BatchJobTerminateOptions(termination_reason="Job killed by user"), + ).result() + self.log.info("Azure Batch job (%s) terminated", self.batch_job_id) def clean_up(self, pool_id: str | None = None, job_id: str | None = None) -> None: """ @@ -320,11 +311,10 @@ def clean_up(self, pool_id: str | None = None, job_id: str | None = None) -> Non :param pool_id: The id of the pool to delete :param job_id: The id of the job to delete - """ if job_id: self.log.info("Deleting job: %s", job_id) - self.hook.connection.job.delete(job_id) + self.hook.connection.begin_delete_job(job_id).result() if pool_id: self.log.info("Deleting pool: %s", pool_id) - self.hook.connection.pool.delete(pool_id) + self.hook.connection.begin_delete_pool(pool_id).result() diff --git a/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_batch.py b/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_batch.py index d17f50665ffa6..3a50be1ecdfa5 100644 --- a/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_batch.py +++ b/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_batch.py @@ -21,7 +21,7 @@ from unittest.mock import PropertyMock import pytest -from azure.batch import BatchServiceClient, models as batch_models +from azure.batch import BatchClient, models as batch_models from airflow.models import Connection from airflow.providers.microsoft.azure.hooks.batch import AzureBatchHook @@ -34,7 
+34,6 @@ class TestAzureBatchHook: def setup_test_cases(self, create_mock_connections): # set up the test variable self.test_vm_conn_id = "test_azure_batch_vm" - self.test_cloud_conn_id = "test_azure_batch_cloud" self.test_account_name = "test_account_name" self.test_account_key = "test_account_key" self.test_account_url = "http://test-endpoint:29000" @@ -42,52 +41,46 @@ def setup_test_cases(self, create_mock_connections): self.test_vm_publisher = "test.vm.publisher" self.test_vm_offer = "test.vm.offer" self.test_vm_sku = "test-sku" - self.test_cloud_os_family = "test-family" - self.test_cloud_os_version = "test-version" self.test_node_agent_sku = "test-node-agent-sku" create_mock_connections( - # connect with vm configuration Connection( conn_id=self.test_vm_conn_id, conn_type="azure-batch", - extra={"account_url": self.test_account_url}, - ), - # connect with cloud service - Connection( - conn_id=self.test_cloud_conn_id, - conn_type="azure-batch", + login=self.test_account_name, + password=self.test_account_key, extra={"account_url": self.test_account_url}, ), ) def test_connection_and_client(self): hook = AzureBatchHook(azure_batch_conn_id=self.test_vm_conn_id) - assert isinstance(hook.get_conn(), BatchServiceClient) + assert isinstance(hook.get_conn(), BatchClient) conn = hook.connection - assert isinstance(conn, BatchServiceClient) + assert isinstance(conn, BatchClient) assert hook.connection is conn, "`connection` property should be cached" - @mock.patch(f"{MODULE}.batch_auth.SharedKeyCredentials") - @mock.patch(f"{MODULE}.AzureIdentityCredentialAdapter") - def test_fallback_to_azure_identity_credential_adppter_when_name_and_key_is_not_provided( - self, mock_azure_identity_credential_adapter, mock_shared_key_credentials + @mock.patch(f"{MODULE}.AzureNamedKeyCredential") + @mock.patch(f"{MODULE}.get_sync_default_azure_credential") + def test_fallback_to_azure_identity_credential_when_name_and_key_is_not_provided( + self, mock_default_credential, 
mock_named_key_credential, create_mock_connections ): - self.test_account_name = None - self.test_account_key = None + conn_id = "test_azure_batch_no_creds" + create_mock_connections( + Connection( + conn_id=conn_id, + conn_type="azure-batch", + extra={"account_url": self.test_account_url}, + ), + ) - hook = AzureBatchHook(azure_batch_conn_id=self.test_vm_conn_id) - assert isinstance(hook.get_conn(), BatchServiceClient) - mock_azure_identity_credential_adapter.assert_called_with( - None, - resource_id="https://batch.core.windows.net/.default", + hook = AzureBatchHook(azure_batch_conn_id=conn_id) + assert isinstance(hook.get_conn(), BatchClient) + mock_default_credential.assert_called_with( managed_identity_client_id=None, workload_identity_tenant_id=None, ) - assert not mock_shared_key_credentials.auth.called - - self.test_account_name = "test_account_name" - self.test_account_key = "test_account_key" + mock_named_key_credential.assert_not_called() def test_configure_pool_with_vm_config(self): hook = AzureBatchHook(azure_batch_conn_id=self.test_vm_conn_id) @@ -100,27 +93,25 @@ def test_configure_pool_with_vm_config(self): vm_offer="test.vm.offer", sku_starts_with="test-sku", ) - assert isinstance(pool, batch_models.PoolAddParameter) + assert isinstance(pool, batch_models.BatchPoolCreateOptions) - def test_configure_pool_with_cloud_config(self): - hook = AzureBatchHook(azure_batch_conn_id=self.test_cloud_conn_id) - pool = hook.configure_pool( - pool_id="mypool", - vm_size="test_vm_size", - vm_node_agent_sku_id=self.test_vm_sku, - target_dedicated_nodes=1, - vm_publisher="test.vm.publisher", - vm_offer="test.vm.offer", - sku_starts_with="test-sku", - ) - assert isinstance(pool, batch_models.PoolAddParameter) + def test_configure_pool_rejects_cloud_service_config(self): + hook = AzureBatchHook(azure_batch_conn_id=self.test_vm_conn_id) + with pytest.raises(ValueError, match="Cloud Service pools"): + hook.configure_pool( + pool_id="mypool", + vm_size="test_vm_size", + 
vm_node_agent_sku_id=self.test_vm_sku, + target_dedicated_nodes=1, + os_family="4", + ) def test_configure_pool_with_latest_vm(self): with mock.patch( "airflow.providers.microsoft.azure.hooks." "batch.AzureBatchHook._get_latest_verified_image_vm_and_sku" ) as mock_getvm: - hook = AzureBatchHook(azure_batch_conn_id=self.test_cloud_conn_id) + hook = AzureBatchHook(azure_batch_conn_id=self.test_vm_conn_id) getvm_instance = mock_getvm getvm_instance.return_value = ["test-image", "test-sku"] pool = hook.configure_pool( @@ -132,12 +123,12 @@ def test_configure_pool_with_latest_vm(self): vm_offer="test.vm.offer", sku_starts_with="test-sku", ) - assert isinstance(pool, batch_models.PoolAddParameter) + assert isinstance(pool, batch_models.BatchPoolCreateOptions) - @mock.patch(f"{MODULE}.BatchServiceClient") + @mock.patch(f"{MODULE}.BatchClient") def test_create_pool_with_vm_config(self, mock_batch): hook = AzureBatchHook(azure_batch_conn_id=self.test_vm_conn_id) - mock_instance = mock_batch.return_value.pool.add + mock_create_pool = mock_batch.return_value.create_pool pool = hook.configure_pool( pool_id="mypool", vm_size="test_vm_size", @@ -148,26 +139,10 @@ def test_create_pool_with_vm_config(self, mock_batch): sku_starts_with="test-sku", ) hook.create_pool(pool=pool) - mock_instance.assert_called_once_with(pool) - - @mock.patch(f"{MODULE}.BatchServiceClient") - def test_create_pool_with_cloud_config(self, mock_batch): - hook = AzureBatchHook(azure_batch_conn_id=self.test_cloud_conn_id) - mock_instance = mock_batch.return_value.pool.add - pool = hook.configure_pool( - pool_id="mypool", - vm_size="test_vm_size", - vm_node_agent_sku_id=self.test_vm_sku, - target_dedicated_nodes=1, - vm_publisher="test.vm.publisher", - vm_offer="test.vm.offer", - sku_starts_with="test-sku", - ) - hook.create_pool(pool=pool) - mock_instance.assert_called_once_with(pool) + mock_create_pool.assert_called_once_with(pool) @mock.patch(f"{MODULE}.time.sleep", return_value=None) - 
@mock.patch(f"{MODULE}.BatchServiceClient") + @mock.patch(f"{MODULE}.BatchClient") def test_wait_for_all_nodes_success_immediately(self, _mock_batch, mock_sleep): hook = AzureBatchHook(azure_batch_conn_id=self.test_vm_conn_id) @@ -176,24 +151,24 @@ def test_wait_for_all_nodes_success_immediately(self, _mock_batch, mock_sleep): pool.target_dedicated_nodes = 2 pool.resize_errors = None - node_idle_1 = mock.Mock(state=batch_models.ComputeNodeState.idle) - node_idle_2 = mock.Mock(state=batch_models.ComputeNodeState.idle) + node_idle_1 = mock.Mock(state=batch_models.BatchNodeState.IDLE) + node_idle_2 = mock.Mock(state=batch_models.BatchNodeState.IDLE) - hook.connection.pool.get.return_value = pool - hook.connection.compute_node.list.return_value = [node_idle_1, node_idle_2] + hook.connection.get_pool.return_value = pool + hook.connection.list_nodes.return_value = [node_idle_1, node_idle_2] nodes = hook.wait_for_all_node_state( pool_id="mypool", - node_state={batch_models.ComputeNodeState.idle}, + node_state={batch_models.BatchNodeState.IDLE}, ) assert nodes == [node_idle_1, node_idle_2] - hook.connection.pool.get.assert_called_once_with("mypool") - hook.connection.compute_node.list.assert_called_once_with("mypool") + hook.connection.get_pool.assert_called_once_with("mypool") + hook.connection.list_nodes.assert_called_once_with("mypool") assert mock_sleep.call_count == 0 @mock.patch(f"{MODULE}.time.sleep", return_value=None) - @mock.patch(f"{MODULE}.BatchServiceClient") + @mock.patch(f"{MODULE}.BatchClient") def test_wait_for_all_nodes_waits_for_node_count(self, _mock_batch, mock_sleep): hook = AzureBatchHook(azure_batch_conn_id=self.test_vm_conn_id) @@ -202,29 +177,29 @@ def test_wait_for_all_nodes_waits_for_node_count(self, _mock_batch, mock_sleep): pool.target_dedicated_nodes = 2 pool.resize_errors = None - node_idle_1 = mock.Mock(state=batch_models.ComputeNodeState.idle) - node_idle_2 = mock.Mock(state=batch_models.ComputeNodeState.idle) + node_idle_1 = 
mock.Mock(state=batch_models.BatchNodeState.IDLE) + node_idle_2 = mock.Mock(state=batch_models.BatchNodeState.IDLE) - hook.connection.pool.get.return_value = pool + hook.connection.get_pool.return_value = pool # First call must return only 1 node. # Second call must return 2 nodes. - hook.connection.compute_node.list.side_effect = [ + hook.connection.list_nodes.side_effect = [ [node_idle_1], [node_idle_1, node_idle_2], ] nodes = hook.wait_for_all_node_state( pool_id="mypool", - node_state={batch_models.ComputeNodeState.idle}, + node_state={batch_models.BatchNodeState.IDLE}, ) assert nodes == [node_idle_1, node_idle_2] - assert hook.connection.compute_node.list.call_count == 2 + assert hook.connection.list_nodes.call_count == 2 assert mock_sleep.call_count == 1 @mock.patch(f"{MODULE}.time.sleep", return_value=None) - @mock.patch(f"{MODULE}.BatchServiceClient") + @mock.patch(f"{MODULE}.BatchClient") def test_wait_for_all_nodes_retries_until_ready(self, _mock_batch, mock_sleep): hook = AzureBatchHook(azure_batch_conn_id=self.test_vm_conn_id) @@ -233,32 +208,32 @@ def test_wait_for_all_nodes_retries_until_ready(self, _mock_batch, mock_sleep): pool.target_dedicated_nodes = 2 pool.resize_errors = None - node_starting_1 = mock.Mock(state=batch_models.ComputeNodeState.starting) - node_starting_2 = mock.Mock(state=batch_models.ComputeNodeState.starting) + node_starting_1 = mock.Mock(state=batch_models.BatchNodeState.STARTING) + node_starting_2 = mock.Mock(state=batch_models.BatchNodeState.STARTING) - node_idle_1 = mock.Mock(state=batch_models.ComputeNodeState.idle) - node_idle_2 = mock.Mock(state=batch_models.ComputeNodeState.idle) + node_idle_1 = mock.Mock(state=batch_models.BatchNodeState.IDLE) + node_idle_2 = mock.Mock(state=batch_models.BatchNodeState.IDLE) - hook.connection.pool.get.return_value = pool + hook.connection.get_pool.return_value = pool # Nodes are not ready in the first poll. # Nodes are ready in the second poll. 
- hook.connection.compute_node.list.side_effect = [ + hook.connection.list_nodes.side_effect = [ [node_starting_1, node_starting_2], [node_idle_1, node_idle_2], ] nodes = hook.wait_for_all_node_state( pool_id="mypool", - node_state={batch_models.ComputeNodeState.idle}, + node_state={batch_models.BatchNodeState.IDLE}, ) assert nodes == [node_idle_1, node_idle_2] - assert hook.connection.compute_node.list.call_count == 2 + assert hook.connection.list_nodes.call_count == 2 assert mock_sleep.call_count == 1 @mock.patch(f"{MODULE}.time.sleep", return_value=None) - @mock.patch(f"{MODULE}.BatchServiceClient") + @mock.patch(f"{MODULE}.BatchClient") def test_wait_for_all_nodes_resize_error(self, _mock_batch, mock_sleep): hook = AzureBatchHook(azure_batch_conn_id=self.test_vm_conn_id) @@ -267,102 +242,89 @@ def test_wait_for_all_nodes_resize_error(self, _mock_batch, mock_sleep): pool.target_dedicated_nodes = 2 pool.resize_errors = ["resize failed"] - hook.connection.pool.get.return_value = pool + hook.connection.get_pool.return_value = pool with pytest.raises(RuntimeError, match="resize error encountered"): hook.wait_for_all_node_state( pool_id="mypool", - node_state={batch_models.ComputeNodeState.idle}, + node_state={batch_models.BatchNodeState.IDLE}, ) assert mock_sleep.call_count == 0 - @mock.patch(f"{MODULE}.BatchServiceClient") + @mock.patch(f"{MODULE}.BatchClient") def test_job_configuration_and_create_job(self, mock_batch): hook = AzureBatchHook(azure_batch_conn_id=self.test_vm_conn_id) - mock_instance = mock_batch.return_value.job.add + mock_create_job = mock_batch.return_value.create_job job = hook.configure_job(job_id="myjob", pool_id="mypool") hook.create_job(job) - assert isinstance(job, batch_models.JobAddParameter) - mock_instance.assert_called_once_with(job) + assert isinstance(job, batch_models.BatchJobCreateOptions) + mock_create_job.assert_called_once_with(job) - @mock.patch(f"{MODULE}.BatchServiceClient") + @mock.patch(f"{MODULE}.BatchClient") def 
test_add_single_task_to_job(self, mock_batch): hook = AzureBatchHook(azure_batch_conn_id=self.test_vm_conn_id) - mock_instance = mock_batch.return_value.task.add + mock_create_task = mock_batch.return_value.create_task task = hook.configure_task(task_id="mytask", command_line="echo hello") hook.add_single_task_to_job(job_id="myjob", task=task) - assert isinstance(task, batch_models.TaskAddParameter) - mock_instance.assert_called_once_with(job_id="myjob", task=task) + assert isinstance(task, batch_models.BatchTaskCreateOptions) + mock_create_task.assert_called_once_with("myjob", task) - @mock.patch(f"{MODULE}.BatchServiceClient") + @mock.patch(f"{MODULE}.BatchClient") def test_wait_for_all_task_to_complete_timeout(self, mock_batch): - hook = AzureBatchHook(azure_batch_conn_id=self.test_cloud_conn_id) + hook = AzureBatchHook(azure_batch_conn_id=self.test_vm_conn_id) with pytest.raises(TimeoutError): hook.wait_for_job_tasks_to_complete("myjob", -1) - @mock.patch(f"{MODULE}.BatchServiceClient") + @mock.patch(f"{MODULE}.BatchClient") def test_wait_for_all_task_to_complete_all_success(self, mock_batch): - hook = AzureBatchHook(azure_batch_conn_id=self.test_cloud_conn_id) - hook.connection.task.list.return_value = iter( - [ - batch_models.CloudTask( - id="mytask_1", - execution_info=batch_models.TaskExecutionInformation( - retry_count=0, requeue_count=0, result=batch_models.TaskExecutionResult.success - ), - state=batch_models.TaskState.completed, - ), - batch_models.CloudTask( - id="mytask_2", - execution_info=batch_models.TaskExecutionInformation( - retry_count=0, requeue_count=0, result=batch_models.TaskExecutionResult.success - ), - state=batch_models.TaskState.completed, - ), - ] + hook = AzureBatchHook(azure_batch_conn_id=self.test_vm_conn_id) + task1 = mock.Mock( + id="mytask_1", + state=batch_models.BatchTaskState.COMPLETED, + execution_info=mock.Mock(result=batch_models.BatchTaskExecutionResult.SUCCESS), ) + task2 = mock.Mock( + id="mytask_2", + 
state=batch_models.BatchTaskState.COMPLETED, + execution_info=mock.Mock(result=batch_models.BatchTaskExecutionResult.SUCCESS), + ) + hook.connection.list_tasks.return_value = iter([task1, task2]) results = hook.wait_for_job_tasks_to_complete("myjob", 60) assert results == [] - hook.connection.task.list.assert_called_once_with("myjob") + hook.connection.list_tasks.assert_called_once_with("myjob") - @mock.patch(f"{MODULE}.BatchServiceClient") + @mock.patch(f"{MODULE}.BatchClient") def test_wait_for_all_task_to_complete_failures(self, mock_batch): - hook = AzureBatchHook(azure_batch_conn_id=self.test_cloud_conn_id) - tasks = [ - batch_models.CloudTask( - id="mytask_1", - execution_info=batch_models.TaskExecutionInformation( - retry_count=0, requeue_count=0, result=batch_models.TaskExecutionResult.success - ), - state=batch_models.TaskState.completed, - ), - batch_models.CloudTask( - id="mytask_2", - execution_info=batch_models.TaskExecutionInformation( - retry_count=0, requeue_count=0, result=batch_models.TaskExecutionResult.failure - ), - state=batch_models.TaskState.completed, - ), - ] - hook.connection.task.list.return_value = iter(tasks) + hook = AzureBatchHook(azure_batch_conn_id=self.test_vm_conn_id) + task1 = mock.Mock( + id="mytask_1", + state=batch_models.BatchTaskState.COMPLETED, + execution_info=mock.Mock(result=batch_models.BatchTaskExecutionResult.SUCCESS), + ) + task2 = mock.Mock( + id="mytask_2", + state=batch_models.BatchTaskState.COMPLETED, + execution_info=mock.Mock(result=batch_models.BatchTaskExecutionResult.FAILURE), + ) + hook.connection.list_tasks.return_value = iter([task1, task2]) results = hook.wait_for_job_tasks_to_complete("myjob", 60) - assert results == [tasks[1]] - hook.connection.task.list.assert_called_once_with("myjob") + assert results == [task2] + hook.connection.list_tasks.assert_called_once_with("myjob") - @mock.patch(f"{MODULE}.BatchServiceClient") + @mock.patch(f"{MODULE}.BatchClient") def test_connection_success(self, 
mock_batch): - hook = AzureBatchHook(azure_batch_conn_id=self.test_cloud_conn_id) - hook.connection.job.return_value = {} + hook = AzureBatchHook(azure_batch_conn_id=self.test_vm_conn_id) + hook.connection.list_jobs.return_value = iter([]) status, msg = hook.test_connection() assert status is True assert msg == "Successfully connected to Azure Batch." - @mock.patch(f"{MODULE}.BatchServiceClient") + @mock.patch(f"{MODULE}.BatchClient") def test_connection_failure(self, mock_batch): - hook = AzureBatchHook(azure_batch_conn_id=self.test_cloud_conn_id) - hook.connection.job.list = PropertyMock(side_effect=Exception("Authentication failed.")) + hook = AzureBatchHook(azure_batch_conn_id=self.test_vm_conn_id) + hook.connection.list_jobs = PropertyMock(side_effect=Exception("Authentication failed.")) status, msg = hook.test_connection() assert status is False assert msg == "Authentication failed." diff --git a/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_batch.py b/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_batch.py index 160aea3df7e79..74eeb4b34c43a 100644 --- a/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_batch.py +++ b/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_batch.py @@ -40,44 +40,30 @@ @pytest.fixture -def mocked_batch_service_client(): - with mock.patch("airflow.providers.microsoft.azure.hooks.batch.BatchServiceClient") as m: +def mocked_batch_client(): + with mock.patch("airflow.providers.microsoft.azure.hooks.batch.BatchClient") as m: yield m class TestAzureBatchOperator: - # set up the test environment @pytest.fixture(autouse=True) - def setup_test_cases(self, mocked_batch_service_client, create_mock_connections): - # set up mocked Azure Batch client - self.batch_client = mock.MagicMock(name="FakeBatchServiceClient") - mocked_batch_service_client.return_value = self.batch_client + def setup_test_cases(self, mocked_batch_client, create_mock_connections): + 
self.batch_client = mock.MagicMock(name="FakeBatchClient") + mocked_batch_client.return_value = self.batch_client - # set up the test variable self.test_vm_conn_id = "test_azure_batch_vm2" - self.test_cloud_conn_id = "test_azure_batch_cloud2" - self.test_account_name = "test_account_name" - self.test_account_key = "test_account_key" self.test_account_url = "http://test-endpoint:29000" - self.test_vm_size = "test-vm-size" self.test_vm_publisher = "test.vm.publisher" self.test_vm_offer = "test.vm.offer" self.test_vm_sku = "test-sku" - self.test_cloud_os_family = "test-family" - self.test_cloud_os_version = "test-version" self.test_node_agent_sku = "test-node-agent-sku" create_mock_connections( - # connect with vm configuration Connection( conn_id=self.test_vm_conn_id, conn_type="azure_batch", - extra=json.dumps({"account_url": self.test_account_url}), - ), - # connect with cloud service - Connection( - conn_id=self.test_cloud_conn_id, - conn_type="azure_batch", + login="test_account_name", + password="test_account_key", extra=json.dumps({"account_url": self.test_account_url}), ), ) @@ -98,46 +84,40 @@ def setup_test_cases(self, mocked_batch_service_client, create_mock_connections) target_dedicated_nodes=1, timeout=2, ) - self.operator2_pass = AzureBatchOperator( + self.operator_auto_scale = AzureBatchOperator( task_id=TASK_ID, batch_pool_id=BATCH_POOL_ID, batch_pool_vm_size=BATCH_VM_SIZE, batch_job_id=BATCH_JOB_ID, batch_task_id=BATCH_TASK_ID, + vm_publisher=self.test_vm_publisher, + vm_offer=self.test_vm_offer, + vm_sku=self.test_vm_sku, vm_node_agent_sku_id=self.test_node_agent_sku, - os_family="4", + sku_starts_with=self.test_vm_sku, batch_task_command_line="echo hello", azure_batch_conn_id=self.test_vm_conn_id, enable_auto_scale=True, auto_scale_formula=FORMULA, timeout=2, ) - self.operator2_no_formula = AzureBatchOperator( + self.operator_auto_scale_no_formula = AzureBatchOperator( task_id=TASK_ID, batch_pool_id=BATCH_POOL_ID, batch_pool_vm_size=BATCH_VM_SIZE, 
batch_job_id=BATCH_JOB_ID, batch_task_id=BATCH_TASK_ID, + vm_publisher=self.test_vm_publisher, + vm_offer=self.test_vm_offer, + vm_sku=self.test_vm_sku, vm_node_agent_sku_id=self.test_node_agent_sku, - os_family="4", + sku_starts_with=self.test_vm_sku, batch_task_command_line="echo hello", azure_batch_conn_id=self.test_vm_conn_id, enable_auto_scale=True, timeout=2, ) - self.operator_fail = AzureBatchOperator( - task_id=TASK_ID, - batch_pool_id=BATCH_POOL_ID, - batch_pool_vm_size=BATCH_VM_SIZE, - batch_job_id=BATCH_JOB_ID, - batch_task_id=BATCH_TASK_ID, - vm_node_agent_sku_id=self.test_node_agent_sku, - os_family="4", - batch_task_command_line="echo hello", - azure_batch_conn_id=self.test_vm_conn_id, - timeout=2, - ) - self.operator_mutual_exclusive = AzureBatchOperator( + self.operator_no_dedicated_no_auto = AzureBatchOperator( task_id=TASK_ID, batch_pool_id=BATCH_POOL_ID, batch_pool_vm_size=BATCH_VM_SIZE, @@ -147,14 +127,12 @@ def setup_test_cases(self, mocked_batch_service_client, create_mock_connections) vm_offer=self.test_vm_offer, vm_sku=self.test_vm_sku, vm_node_agent_sku_id=self.test_node_agent_sku, - os_family="5", sku_starts_with=self.test_vm_sku, batch_task_command_line="echo hello", azure_batch_conn_id=self.test_vm_conn_id, - target_dedicated_nodes=1, timeout=2, ) - self.operator_invalid = AzureBatchOperator( + self.operator_no_publisher = AzureBatchOperator( task_id=TASK_ID, batch_pool_id=BATCH_POOL_ID, batch_pool_vm_size=BATCH_VM_SIZE, @@ -171,19 +149,18 @@ def setup_test_cases(self, mocked_batch_service_client, create_mock_connections) def test_execute_without_failures(self, wait_mock): wait_mock.return_value = True # No wait self.operator.execute(None) - # test pool creation - self.batch_client.pool.add.assert_called() - self.batch_client.job.add.assert_called() - self.batch_client.task.add.assert_called() + # test pool/job/task creation use the new flat method names + self.batch_client.create_pool.assert_called() + 
self.batch_client.create_job.assert_called() + self.batch_client.create_task.assert_called() @mock.patch.object(AzureBatchHook, "wait_for_all_node_state") - def test_execute_without_failures_2(self, wait_mock): + def test_execute_with_auto_scale(self, wait_mock): wait_mock.return_value = True # No wait - self.operator2_pass.execute(None) - # test pool creation - self.batch_client.pool.add.assert_called() - self.batch_client.job.add.assert_called() - self.batch_client.task.add.assert_called() + self.operator_auto_scale.execute(None) + self.batch_client.create_pool.assert_called() + self.batch_client.create_job.assert_called() + self.batch_client.create_task.assert_called() @mock.patch.object(AzureBatchHook, "wait_for_all_node_state") def test_execute_with_failures(self, wait_mock): @@ -199,17 +176,16 @@ def test_execute_with_failures(self, wait_mock): @mock.patch.object(AzureBatchOperator, "clean_up") def test_execute_with_cleaning(self, mock_clean, wait_mock): wait_mock.return_value = True # No wait - # Remove pool id self.operator.should_delete_job = True self.operator.execute(None) mock_clean.assert_called() mock_clean.assert_called_once_with(job_id=BATCH_JOB_ID) @mock.patch.object(AzureBatchHook, "wait_for_all_node_state") - def test_operator_fails(self, wait_mock): + def test_operator_fails_no_dedicated_no_auto(self, wait_mock): wait_mock.return_value = True with pytest.raises(AirflowException) as ctx: - self.operator_fail.execute(None) + self.operator_no_dedicated_no_auto.execute(None) assert ( str(ctx.value) == "Either target_dedicated_nodes or enable_auto_scale must be set. 
None was set" ) @@ -218,32 +194,51 @@ def test_operator_fails(self, wait_mock): def test_operator_fails_no_formula(self, wait_mock): wait_mock.return_value = True with pytest.raises(AirflowException) as ctx: - self.operator2_no_formula.execute(None) + self.operator_auto_scale_no_formula.execute(None) assert str(ctx.value) == "The auto_scale_formula is required when enable_auto_scale is set" - @mock.patch.object(AzureBatchHook, "wait_for_all_node_state") - def test_operator_fails_mutual_exclusive(self, wait_mock): - wait_mock.return_value = True - with pytest.raises(AirflowException) as ctx: - self.operator_mutual_exclusive.execute(None) - assert ( - str(ctx.value) == "Cloud service configuration and virtual machine configuration " - "are mutually exclusive. You must specify either of os_family and" - " vm_publisher" - ) + def test_operator_construction_rejects_os_family(self): + with pytest.raises(ValueError, match="Cloud Service pools"): + AzureBatchOperator( + task_id=TASK_ID, + batch_pool_id=BATCH_POOL_ID, + batch_pool_vm_size=BATCH_VM_SIZE, + batch_job_id=BATCH_JOB_ID, + batch_task_id=BATCH_TASK_ID, + vm_node_agent_sku_id=self.test_node_agent_sku, + os_family="4", + batch_task_command_line="echo hello", + azure_batch_conn_id=self.test_vm_conn_id, + target_dedicated_nodes=1, + timeout=2, + ) + + def test_operator_construction_rejects_os_version(self): + with pytest.raises(ValueError, match="Cloud Service pools"): + AzureBatchOperator( + task_id=TASK_ID, + batch_pool_id=BATCH_POOL_ID, + batch_pool_vm_size=BATCH_VM_SIZE, + batch_job_id=BATCH_JOB_ID, + batch_task_id=BATCH_TASK_ID, + vm_node_agent_sku_id=self.test_node_agent_sku, + os_version="2", + batch_task_command_line="echo hello", + azure_batch_conn_id=self.test_vm_conn_id, + target_dedicated_nodes=1, + timeout=2, + ) @mock.patch.object(AzureBatchHook, "wait_for_all_node_state") - def test_operator_fails_invalid_args(self, wait_mock): + def test_operator_fails_no_publisher(self, wait_mock): 
wait_mock.return_value = True - with pytest.raises(AirflowException) as ctx: - self.operator_invalid.execute(None) - assert str(ctx.value) == "You must specify either vm_publisher or os_family" + with pytest.raises(ValueError, match="vm_publisher is required"): + self.operator_no_publisher.execute(None) def test_cleaning_works(self): self.operator.clean_up(job_id="myjob") - self.batch_client.job.delete.assert_called_once_with("myjob") + self.batch_client.begin_delete_job.assert_called_once_with("myjob") + self.batch_client.begin_delete_job.return_value.result.assert_called_once() self.operator.clean_up("mypool") - self.batch_client.pool.delete.assert_called_once_with("mypool") - self.operator.clean_up("mypool", "myjob") - self.batch_client.job.delete.assert_called_with("myjob") - self.batch_client.pool.delete.assert_called_with("mypool") + self.batch_client.begin_delete_pool.assert_called_once_with("mypool") + self.batch_client.begin_delete_pool.return_value.result.assert_called_once() diff --git a/scripts/ci/prek/known_airflow_exceptions.txt b/scripts/ci/prek/known_airflow_exceptions.txt index 6175e6ad00c00..867f62db08d23 100644 --- a/scripts/ci/prek/known_airflow_exceptions.txt +++ b/scripts/ci/prek/known_airflow_exceptions.txt @@ -346,14 +346,13 @@ providers/jenkins/src/airflow/providers/jenkins/sensors/jenkins.py::1 providers/keycloak/src/airflow/providers/keycloak/auth_manager/keycloak_auth_manager.py::4 providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/adx.py::4 providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/base_azure.py::1 -providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/batch.py::1 providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/container_instance.py::1 providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/cosmos.py::1 providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/data_factory.py::26 
providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/data_lake.py::1 providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/msgraph.py::3 providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/wasb.py::2 -providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/batch.py::10 +providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/batch.py::8 providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/container_instances.py::10 providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/data_factory.py::1 providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/msgraph.py::1