diff --git a/airflow/providers/apache/beam/hooks/beam.py b/airflow/providers/apache/beam/hooks/beam.py index efea53560b410..29ecfa46516c0 100644 --- a/airflow/providers/apache/beam/hooks/beam.py +++ b/airflow/providers/apache/beam/hooks/beam.py @@ -508,6 +508,25 @@ async def start_python_pipeline_async( ) return return_code + async def start_java_pipeline_async(self, variables: dict, jar: str, job_class: str | None = None): + """ + Start Apache Beam Java pipeline. + + :param variables: Variables passed to the job. + :param jar: Name of the jar for the pipeline. + :param job_class: Name of the java class for the pipeline. + :return: Beam command execution return code. + """ + if "labels" in variables: + variables["labels"] = json.dumps(variables["labels"], separators=(",", ":")) + + command_prefix = ["java", "-cp", jar, job_class] if job_class else ["java", "-jar", jar] + return_code = await self.start_pipeline_async( + variables=variables, + command_prefix=command_prefix, + ) + return return_code + async def start_pipeline_async( self, variables: dict, diff --git a/airflow/providers/apache/beam/operators/beam.py b/airflow/providers/apache/beam/operators/beam.py index 876f47fa2db7d..bf75f3caa74fb 100644 --- a/airflow/providers/apache/beam/operators/beam.py +++ b/airflow/providers/apache/beam/operators/beam.py @@ -34,7 +34,7 @@ from airflow.exceptions import AirflowException from airflow.models import BaseOperator from airflow.providers.apache.beam.hooks.beam import BeamHook, BeamRunnerType -from airflow.providers.apache.beam.triggers.beam import BeamPipelineTrigger +from airflow.providers.apache.beam.triggers.beam import BeamJavaPipelineTrigger, BeamPythonPipelineTrigger from airflow.providers.google.cloud.hooks.dataflow import ( DataflowHook, process_line_and_extract_dataflow_job_id_callback, @@ -239,6 +239,22 @@ def _init_pipeline_options( check_job_status_callback, ) + def execute_complete(self, context: Context, event: dict[str, Any]): + """ + Execute when the trigger fires - returns immediately. + + Relies on trigger to throw an exception, otherwise it assumes execution was + successful. + """ + if event["status"] == "error": + raise AirflowException(event["message"]) + self.log.info( + "%s completed with response %s ", + self.task_id, + event["message"], + ) + return {"dataflow_job_id": self.dataflow_job_id} + class BeamRunPythonPipelineOperator(BeamBasePipelineOperator): """ @@ -323,7 +339,7 @@ def __init__( self.deferrable = deferrable def execute(self, context: Context): - """Execute the Apache Beam Pipeline.""" + """Execute the Apache Beam Python Pipeline.""" ( self.is_dataflow, self.dataflow_job_name, @@ -408,7 +424,7 @@ async def execute_async(self, context: Context): ) with self.dataflow_hook.provide_authorized_gcloud(): self.defer( - trigger=BeamPipelineTrigger( + trigger=BeamPythonPipelineTrigger( variables=self.snake_case_pipeline_options, py_file=self.py_file, py_options=self.py_options, @@ -421,7 +437,7 @@ async def execute_async(self, context: Context): ) else: self.defer( - trigger=BeamPipelineTrigger( + trigger=BeamPythonPipelineTrigger( variables=self.snake_case_pipeline_options, py_file=self.py_file, py_options=self.py_options, @@ -433,22 +449,6 @@ async def execute_async(self, context: Context): method_name="execute_complete", ) - def execute_complete(self, context: Context, event: dict[str, Any]): - """ - Execute when the trigger fires - returns immediately. - - Relies on trigger to throw an exception, otherwise it assumes execution was - successful. 
- """ - if event["status"] == "error": - raise AirflowException(event["message"]) - self.log.info( - "%s completed with response %s ", - self.task_id, - event["message"], - ) - return {"dataflow_job_id": self.dataflow_job_id} - def on_kill(self) -> None: if self.dataflow_hook and self.dataflow_job_id: self.log.info("Dataflow job with id: `%s` was requested to be cancelled.", self.dataflow_job_id) @@ -509,6 +509,7 @@ def __init__( pipeline_options: dict | None = None, gcp_conn_id: str = "google_cloud_default", dataflow_config: DataflowConfiguration | dict | None = None, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), **kwargs, ) -> None: super().__init__( @@ -521,61 +522,55 @@ def __init__( ) self.jar = jar self.job_class = job_class + self.deferrable = deferrable def execute(self, context: Context): - """Execute the Apache Beam Pipeline.""" + """Execute the Apache Beam Python Pipeline.""" ( - is_dataflow, - dataflow_job_name, - pipeline_options, - process_line_callback, + self.is_dataflow, + self.dataflow_job_name, + self.pipeline_options, + self.process_line_callback, _, ) = self._init_pipeline_options() - if not self.beam_hook: raise AirflowException("Beam hook is not defined.") + if self.deferrable: + asyncio.run(self.execute_async(context)) + else: + return self.execute_sync(context) + def execute_sync(self, context: Context): + """Execute the Apache Beam Pipeline.""" with ExitStack() as exit_stack: if self.jar.lower().startswith("gs://"): gcs_hook = GCSHook(self.gcp_conn_id) tmp_gcs_file = exit_stack.enter_context(gcs_hook.provide_file(object_url=self.jar)) self.jar = tmp_gcs_file.name - if is_dataflow and self.dataflow_hook: - is_running = False - if self.dataflow_config.check_if_running != CheckJobRunning.IgnoreJob: - is_running = ( - # The reason for disable=no-value-for-parameter is that project_id parameter is - # required but here is not passed, moreover it cannot be passed here. - # This method is wrapped by @_fallback_to_project_id_from_variables decorator which - # fallback project_id value from variables and raise error if project_id is - # defined both in variables and as parameter (here is already defined in variables) - self.dataflow_hook.is_job_dataflow_running( - name=self.dataflow_config.job_name, - variables=pipeline_options, - ) + if self.is_dataflow and self.dataflow_hook: + is_running = self.dataflow_config.check_if_running == CheckJobRunning.WaitForRun + while is_running and self.dataflow_config.check_if_running == CheckJobRunning.WaitForRun: + # The reason for disable=no-value-for-parameter is that project_id parameter is + # required but here is not passed, moreover it cannot be passed here. + # This method is wrapped by @_fallback_to_project_id_from_variables decorator which + # fallback project_id value from variables and raise error if project_id is + # defined both in variables and as parameter (here is already defined in variables) + is_running = self.dataflow_hook.is_job_dataflow_running( + name=self.dataflow_config.job_name, + variables=self.pipeline_options, ) - while is_running and self.dataflow_config.check_if_running == CheckJobRunning.WaitForRun: - # The reason for disable=no-value-for-parameter is that project_id parameter is - # required but here is not passed, moreover it cannot be passed here. 
- # This method is wrapped by @_fallback_to_project_id_from_variables decorator which - # fallback project_id value from variables and raise error if project_id is - # defined both in variables and as parameter (here is already defined in variables) - - is_running = self.dataflow_hook.is_job_dataflow_running( - name=self.dataflow_config.job_name, - variables=pipeline_options, - ) + if not is_running: - pipeline_options["jobName"] = dataflow_job_name + self.pipeline_options["jobName"] = self.dataflow_job_name with self.dataflow_hook.provide_authorized_gcloud(): self.beam_hook.start_java_pipeline( - variables=pipeline_options, + variables=self.pipeline_options, jar=self.jar, job_class=self.job_class, - process_line_callback=process_line_callback, + process_line_callback=self.process_line_callback, ) - if dataflow_job_name and self.dataflow_config.location: + if self.dataflow_job_name and self.dataflow_config.location: multiple_jobs = self.dataflow_config.multiple_jobs or False DataflowJobLink.persist( self, @@ -585,7 +580,7 @@ def execute(self, context: Context): self.dataflow_job_id, ) self.dataflow_hook.wait_for_done( - job_name=dataflow_job_name, + job_name=self.dataflow_job_name, location=self.dataflow_config.location, job_id=self.dataflow_job_id, multiple_jobs=multiple_jobs, @@ -594,11 +589,65 @@ def execute(self, context: Context): return {"dataflow_job_id": self.dataflow_job_id} else: self.beam_hook.start_java_pipeline( - variables=pipeline_options, + variables=self.pipeline_options, jar=self.jar, job_class=self.job_class, - process_line_callback=process_line_callback, + process_line_callback=self.process_line_callback, + ) + + async def execute_async(self, context: Context): + # Creating a new event loop to manage I/O operations asynchronously + loop = asyncio.get_event_loop() + if self.jar.lower().startswith("gs://"): + gcs_hook = GCSHook(self.gcp_conn_id) + # Running synchronous `enter_context()` method in a separate + # thread using the default executor `None`. The `run_in_executor()` function returns the + # file object, which is created using gcs function `provide_file()`, asynchronously. + # This means we can perform asynchronous operations with this file. 
+ create_tmp_file_call = gcs_hook.provide_file(object_url=self.jar) + tmp_gcs_file: IO[str] = await loop.run_in_executor( + None, contextlib.ExitStack().enter_context, create_tmp_file_call + ) + self.jar = tmp_gcs_file.name + + if self.is_dataflow and self.dataflow_hook: + DataflowJobLink.persist( + self, + context, + self.dataflow_config.project_id, + self.dataflow_config.location, + self.dataflow_job_id, + ) + with self.dataflow_hook.provide_authorized_gcloud(): + self.pipeline_options["jobName"] = self.dataflow_job_name + self.defer( + trigger=BeamJavaPipelineTrigger( + variables=self.pipeline_options, + jar=self.jar, + job_class=self.job_class, + runner=self.runner, + check_if_running=self.dataflow_config.check_if_running == CheckJobRunning.WaitForRun, + project_id=self.dataflow_config.project_id, + location=self.dataflow_config.location, + job_name=self.dataflow_job_name, + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.dataflow_config.impersonation_chain, + poll_sleep=self.dataflow_config.poll_sleep, + cancel_timeout=self.dataflow_config.cancel_timeout, + ), + method_name="execute_complete", ) + else: + self.defer( + trigger=BeamJavaPipelineTrigger( + variables=self.pipeline_options, + jar=self.jar, + job_class=self.job_class, + runner=self.runner, + check_if_running=self.dataflow_config.check_if_running == CheckJobRunning.WaitForRun, + ), + method_name="execute_complete", + ) def on_kill(self) -> None: if self.dataflow_hook and self.dataflow_job_id: diff --git a/airflow/providers/apache/beam/triggers/beam.py b/airflow/providers/apache/beam/triggers/beam.py index 0d201cd8c9a4d..4caa46d1e50f0 100644 --- a/airflow/providers/apache/beam/triggers/beam.py +++ b/airflow/providers/apache/beam/triggers/beam.py @@ -16,15 +16,33 @@ # under the License. from __future__ import annotations -from typing import Any, AsyncIterator +import asyncio +import warnings +from typing import Any, AsyncIterator, Sequence +from google.cloud.dataflow_v1beta3 import ListJobsRequest + +from airflow.exceptions import AirflowProviderDeprecationWarning from airflow.providers.apache.beam.hooks.beam import BeamAsyncHook +from airflow.providers.google.cloud.hooks.dataflow import AsyncDataflowHook from airflow.triggers.base import BaseTrigger, TriggerEvent -class BeamPipelineTrigger(BaseTrigger): +class BeamPipelineBaseTrigger(BaseTrigger): + """Base class for Beam Pipeline Triggers.""" + + @staticmethod + def _get_async_hook(*args, **kwargs) -> BeamAsyncHook: + return BeamAsyncHook(*args, **kwargs) + + @staticmethod + def _get_sync_dataflow_hook(**kwargs) -> AsyncDataflowHook: + return AsyncDataflowHook(**kwargs) + + +class BeamPythonPipelineTrigger(BeamPipelineBaseTrigger): """ - Trigger to perform checking the pipeline status until it reaches terminate state. + Trigger to perform checking the Python pipeline status until it reaches terminate state. :param variables: Variables passed to the pipeline. :param py_file: Path to the python file to execute. @@ -35,12 +53,10 @@ class BeamPipelineTrigger(BaseTrigger): :param py_requirements: Additional python package(s) to install. If a value is passed to this parameter, a new virtual environment has been created with additional packages installed. - You could also install the apache-beam package if it is not installed on your system, or you want to use a different version. :param py_system_site_packages: Whether to include system_site_packages in your virtualenv. See virtualenv documentation for more information. 
- This option is only relevant if the ``py_requirements`` parameter is not None. :param runner: Runner on which pipeline will be run. By default, "DirectRunner" is being used. Other possible options: DataflowRunner, SparkRunner, FlinkRunner, PortableRunner. @@ -68,9 +84,9 @@ def __init__( self.runner = runner def serialize(self) -> tuple[str, dict[str, Any]]: - """Serialize BeamPipelineTrigger arguments and classpath.""" + """Serialize BeamPythonPipelineTrigger arguments and classpath.""" return ( - "airflow.providers.apache.beam.triggers.beam.BeamPipelineTrigger", + "airflow.providers.apache.beam.triggers.beam.BeamPythonPipelineTrigger", { "variables": self.variables, "py_file": self.py_file, @@ -84,7 +100,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]: async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override] """Get current pipeline status and yields a TriggerEvent.""" - hook = self._get_async_hook() + hook = self._get_async_hook(runner=self.runner) try: return_code = await hook.start_python_pipeline_async( variables=self.variables, @@ -109,5 +125,146 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override] yield TriggerEvent({"status": "error", "message": "Operation failed"}) return - def _get_async_hook(self) -> BeamAsyncHook: - return BeamAsyncHook(runner=self.runner) + +class BeamJavaPipelineTrigger(BeamPipelineBaseTrigger): + """ + Trigger to perform checking the Java pipeline status until it reaches terminate state. + + :param variables: Variables passed to the job. + :param jar: Name of the jar for the pipeline. + :param job_class: Optional. Name of the java class for the pipeline. + :param runner: Runner on which pipeline will be run. By default, "DirectRunner" is being used. + Other possible options: DataflowRunner, SparkRunner, FlinkRunner, PortableRunner. + See: :class:`~providers.apache.beam.hooks.beam.BeamRunnerType` + See: https://beam.apache.org/documentation/runners/capability-matrix/ + :param check_if_running: Optional. Before running job, validate that a previous run is not in process. + :param project_id: Optional. The Google Cloud project ID in which to start a job. + :param location: Optional. Job location. + :param job_name: Optional. The 'jobName' to use when executing the Dataflow job. + :param gcp_conn_id: Optional. The connection ID to use connecting to Google Cloud. + :param impersonation_chain: Optional. GCP service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :param poll_sleep: Optional. The time in seconds to sleep between polling GCP for the dataflow job status. + Default value is 10s. + :param cancel_timeout: Optional. How long (in seconds) operator should wait for the pipeline to be + successfully cancelled when task is being killed. Default value is 300s. 
+ """ + + def __init__( + self, + variables: dict, + jar: str, + job_class: str | None = None, + runner: str = "DirectRunner", + check_if_running: bool = False, + project_id: str | None = None, + location: str | None = None, + job_name: str | None = None, + gcp_conn_id: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, + poll_sleep: int = 10, + cancel_timeout: int | None = None, + ): + super().__init__() + self.variables = variables + self.jar = jar + self.job_class = job_class + self.runner = runner + self.check_if_running = check_if_running + self.project_id = project_id + self.location = location + self.job_name = job_name + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + self.poll_sleep = poll_sleep + self.cancel_timeout = cancel_timeout + + def serialize(self) -> tuple[str, dict[str, Any]]: + """Serialize BeamJavaPipelineTrigger arguments and classpath.""" + return ( + "airflow.providers.apache.beam.triggers.beam.BeamJavaPipelineTrigger", + { + "variables": self.variables, + "jar": self.jar, + "job_class": self.job_class, + "runner": self.runner, + "check_if_running": self.check_if_running, + "project_id": self.project_id, + "location": self.location, + "job_name": self.job_name, + "gcp_conn_id": self.gcp_conn_id, + "impersonation_chain": self.impersonation_chain, + "poll_sleep": self.poll_sleep, + "cancel_timeout": self.cancel_timeout, + }, + ) + + async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override] + """Get current Java pipeline status and yields a TriggerEvent.""" + hook = self._get_async_hook(runner=self.runner) + + return_code = 0 + if self.check_if_running: + dataflow_hook = self._get_sync_dataflow_hook( + gcp_conn_id=self.gcp_conn_id, + poll_sleep=self.poll_sleep, + impersonation_chain=self.impersonation_chain, + cancel_timeout=self.cancel_timeout, + ) + is_running = True + while is_running: + try: + jobs = await dataflow_hook.list_jobs( + project_id=self.project_id, + location=self.location, + jobs_filter=ListJobsRequest.Filter.ACTIVE, + ) + is_running = bool([job async for job in jobs if job.name == self.job_name]) + except Exception as e: + self.log.exception(f"Exception occurred while requesting jobs with name {self.job_name}") + yield TriggerEvent({"status": "error", "message": str(e)}) + return + if is_running: + await asyncio.sleep(self.poll_sleep) + try: + return_code = await hook.start_java_pipeline_async( + variables=self.variables, jar=self.jar, job_class=self.job_class + ) + except Exception as e: + self.log.exception("Exception occurred while starting the Java pipeline") + yield TriggerEvent({"status": "error", "message": str(e)}) + + if return_code == 0: + yield TriggerEvent( + { + "status": "success", + "message": "Pipeline has finished SUCCESSFULLY", + } + ) + else: + yield TriggerEvent({"status": "error", "message": "Operation failed"}) + return + + +class BeamPipelineTrigger(BeamPythonPipelineTrigger): + """ + Trigger to perform checking the Python pipeline status until it reaches terminate state. + + This class is deprecated. Please use + :class:`airflow.providers.apache.beam.triggers.beam.BeamPythonPipelineTrigger` + instead. + """ + + def __init__(self, *args, **kwargs): + warnings.warn( + "`BeamPipelineTrigger` is deprecated. 
Please use `BeamPythonPipelineTrigger`.", + AirflowProviderDeprecationWarning, + stacklevel=2, + ) + super().__init__(*args, **kwargs) diff --git a/airflow/providers/google/cloud/hooks/dataflow.py b/airflow/providers/google/cloud/hooks/dataflow.py index 6c051ce6190a1..486f9225a2860 100644 --- a/airflow/providers/google/cloud/hooks/dataflow.py +++ b/airflow/providers/google/cloud/hooks/dataflow.py @@ -27,9 +27,10 @@ import uuid import warnings from copy import deepcopy -from typing import Any, Callable, Generator, Sequence, TypeVar, cast +from typing import TYPE_CHECKING, Any, Callable, Generator, Sequence, TypeVar, cast from google.cloud.dataflow_v1beta3 import GetJobRequest, Job, JobState, JobsV1Beta3AsyncClient, JobView +from google.cloud.dataflow_v1beta3.types.jobs import ListJobsRequest from googleapiclient.discovery import build from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning @@ -42,6 +43,10 @@ from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.timeout import timeout +if TYPE_CHECKING: + from google.cloud.dataflow_v1beta3.services.jobs_v1_beta3.pagers import ListJobsAsyncPager + + # This is the default location # https://cloud.google.com/dataflow/pipelines/specifying-exec-params DEFAULT_DATAFLOW_LOCATION = "us-central1" @@ -219,7 +224,7 @@ def __init__( def is_job_running(self) -> bool: """ - Helper method to check if jos is still running in dataflow. + Helper method to check if job is still running in dataflow. :return: True if job is running. """ @@ -1313,3 +1318,38 @@ async def get_job_status( ) state = job.current_state return state + + async def list_jobs( + self, + jobs_filter: int | None = None, + project_id: str | None = PROVIDE_PROJECT_ID, + location: str | None = DEFAULT_DATAFLOW_LOCATION, + page_size: int | None = None, + page_token: str | None = None, + ) -> ListJobsAsyncPager: + """List jobs. + + For detail see: + https://cloud.google.com/python/docs/reference/dataflow/latest/google.cloud.dataflow_v1beta3.types.ListJobsRequest + + :param jobs_filter: Optional. This field filters out and returns jobs in the specified job state. + :param project_id: Optional. The Google Cloud project ID in which to start a job. + If set to None or missing, the default project_id from the Google Cloud connection is used. + :param location: Optional. The location of the Dataflow job (for example europe-west1). + :param page_size: Optional. If there are many jobs, limit response to at most this many. + :param page_token: Optional. Set this to the 'next_page_token' field of a previous response to request + additional results in a long list. 
+ """ + project_id = project_id or (await self.get_project_id()) + client = await self.initialize_client(JobsV1Beta3AsyncClient) + request: ListJobsRequest = ListJobsRequest( + { + "project_id": project_id, + "location": location, + "filter": jobs_filter, + "page_size": page_size, + "page_token": page_token, + } + ) + page_result: ListJobsAsyncPager = await client.list_jobs(request=request) + return page_result diff --git a/docs/apache-airflow-providers-google/operators/cloud/dataflow.rst b/docs/apache-airflow-providers-google/operators/cloud/dataflow.rst index 3851cb8fa8122..905bb5790bca5 100644 --- a/docs/apache-airflow-providers-google/operators/cloud/dataflow.rst +++ b/docs/apache-airflow-providers-google/operators/cloud/dataflow.rst @@ -87,6 +87,14 @@ Here is an example of creating and running a pipeline in Java with jar stored on :start-after: [START howto_operator_start_java_job_jar_on_gcs] :end-before: [END howto_operator_start_java_job_jar_on_gcs] +Here is an example of creating and running a pipeline in Java with jar stored on GCS in deferrable mode: + +.. exampleinclude:: /../../tests/system/providers/google/cloud/dataflow/example_dataflow_native_java.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_start_java_job_jar_on_gcs_deferrable] + :end-before: [END howto_operator_start_java_job_jar_on_gcs_deferrable] + Here is an example of creating and running a pipeline in Java with jar stored on local file system: .. exampleinclude:: /../../tests/system/providers/google/cloud/dataflow/example_dataflow_native_java.py diff --git a/tests/providers/apache/beam/hooks/test_beam.py b/tests/providers/apache/beam/hooks/test_beam.py index ce33b683e2595..10e2808a09e31 100644 --- a/tests/providers/apache/beam/hooks/test_beam.py +++ b/tests/providers/apache/beam/hooks/test_beam.py @@ -42,12 +42,19 @@ GO_FILE = "/path/to/file.go" DEFAULT_RUNNER = "DirectRunner" BEAM_STRING = "airflow.providers.apache.beam.hooks.beam.{}" +BEAM_VARIABLES = {"output": "gs://test/output", "labels": {"foo": "bar"}} BEAM_VARIABLES_PY = {"output": "gs://test/output", "labels": {"foo": "bar"}} BEAM_VARIABLES_JAVA = { "output": "gs://test/output", "labels": {"foo": "bar"}, } +BEAM_VARIABLES_JAVA_STRING_LABELS = { + "output": "gs://test/output", + "labels": '{"foo":"bar"}', +} BEAM_VARIABLES_GO = {"output": "gs://test/output", "labels": {"foo": "bar"}} +PIPELINE_COMMAND_PREFIX = ["a", "b", "c"] +WORKING_DIRECTORY = "test_wd" APACHE_BEAM_V_2_14_0_JAVA_SDK_LOG = f""""\ Dataflow SDK version: 2.14.0 @@ -418,6 +425,25 @@ def test_beam_options_to_args(self, options, expected_args): class TestBeamAsyncHook: + @pytest.mark.asyncio + @mock.patch("airflow.providers.apache.beam.hooks.beam.BeamAsyncHook.run_beam_command_async") + async def test_start_pipline_async(self, mock_runner): + expected_cmd = [ + *PIPELINE_COMMAND_PREFIX, + f"--runner={DEFAULT_RUNNER}", + *beam_options_to_args(BEAM_VARIABLES), + ] + hook = BeamAsyncHook(runner=DEFAULT_RUNNER) + await hook.start_pipeline_async( + variables=BEAM_VARIABLES, + command_prefix=PIPELINE_COMMAND_PREFIX, + working_directory=WORKING_DIRECTORY, + ) + + mock_runner.assert_called_once_with( + cmd=expected_cmd, working_directory=WORKING_DIRECTORY, log=hook.log + ) + @pytest.mark.asyncio @mock.patch("airflow.providers.apache.beam.hooks.beam.BeamAsyncHook.run_beam_command_async") @mock.patch("airflow.providers.apache.beam.hooks.beam.BeamAsyncHook._create_tmp_dir") @@ -583,3 +609,21 @@ async def test_start_python_pipeline_with_empty_py_requirements_and_without_syst ) 
mock_runner.assert_not_called() + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "job_class, command_prefix", + [ + (JOB_CLASS, ["java", "-cp", JAR_FILE, JOB_CLASS]), + (None, ["java", "-jar", JAR_FILE]), + ], + ) + @mock.patch("airflow.providers.apache.beam.hooks.beam.BeamAsyncHook.start_pipeline_async") + async def test_start_java_pipeline_async(self, mock_start_pipeline, job_class, command_prefix): + variables = copy.deepcopy(BEAM_VARIABLES_JAVA) + hook = BeamAsyncHook(runner=DEFAULT_RUNNER) + await hook.start_java_pipeline_async(variables=variables, jar=JAR_FILE, job_class=job_class) + + mock_start_pipeline.assert_called_once_with( + variables=BEAM_VARIABLES_JAVA_STRING_LABELS, command_prefix=command_prefix + ) diff --git a/tests/providers/apache/beam/operators/test_beam.py b/tests/providers/apache/beam/operators/test_beam.py index 8b6f57ccccc36..538c2417a0e38 100644 --- a/tests/providers/apache/beam/operators/test_beam.py +++ b/tests/providers/apache/beam/operators/test_beam.py @@ -24,11 +24,12 @@ from airflow.exceptions import AirflowException, TaskDeferred from airflow.providers.apache.beam.operators.beam import ( + BeamBasePipelineOperator, BeamRunGoPipelineOperator, BeamRunJavaPipelineOperator, BeamRunPythonPipelineOperator, ) -from airflow.providers.apache.beam.triggers.beam import BeamPipelineTrigger +from airflow.providers.apache.beam.triggers.beam import BeamJavaPipelineTrigger, BeamPythonPipelineTrigger from airflow.providers.google.cloud.operators.dataflow import DataflowConfiguration from airflow.version import version @@ -60,6 +61,35 @@ "labels": {"foo": "bar", "airflow-version": TEST_VERSION}, } TEST_IMPERSONATION_ACCOUNT = "test@impersonation.com" +BEAM_OPERATOR_PATH = "airflow.providers.apache.beam.operators.beam.{}" + + +class TestBeamBasePipelineOperator: + def setup_method(self): + self.operator = BeamBasePipelineOperator( + task_id=TASK_ID, + runner=DEFAULT_RUNNER, + ) + + def test_async_execute_should_throw_exception(self): + """Tests that an AirflowException is raised in case of error event""" + + with pytest.raises(AirflowException): + self.operator.execute_complete( + context=mock.MagicMock(), event={"status": "error", "message": "test failure message"} + ) + + def test_async_execute_logging_should_execute_successfully(self): + """Asserts that logging occurs as expected""" + + with mock.patch.object(self.operator.log, "info") as mock_log_info: + self.operator.execute_complete( + context=mock.MagicMock(), + event={"status": "success", "message": "Pipeline has finished SUCCESSFULLY"}, + ) + mock_log_info.assert_called_with( + "%s completed with response %s ", TASK_ID, "Pipeline has finished SUCCESSFULLY" + ) class TestBeamRunPythonPipelineOperator: @@ -82,8 +112,8 @@ def test_init(self): assert self.operator.default_pipeline_options == PY_DEFAULT_OPTIONS assert self.operator.pipeline_options == EXPECTED_ADDITIONAL_OPTIONS - @mock.patch("airflow.providers.apache.beam.operators.beam.BeamHook") - @mock.patch("airflow.providers.apache.beam.operators.beam.GCSHook") + @mock.patch(BEAM_OPERATOR_PATH.format("BeamHook")) + @mock.patch(BEAM_OPERATOR_PATH.format("GCSHook")) def test_exec_direct_runner(self, gcs_hook, beam_hook_mock): """Test BeamHook is created and the right args are passed to start_python_workflow. 
@@ -111,10 +141,10 @@ def test_exec_direct_runner(self, gcs_hook, beam_hook_mock): process_line_callback=None, ) - @mock.patch("airflow.providers.apache.beam.operators.beam.DataflowJobLink.persist") - @mock.patch("airflow.providers.apache.beam.operators.beam.BeamHook") - @mock.patch("airflow.providers.apache.beam.operators.beam.DataflowHook") - @mock.patch("airflow.providers.apache.beam.operators.beam.GCSHook") + @mock.patch(BEAM_OPERATOR_PATH.format("DataflowJobLink.persist")) + @mock.patch(BEAM_OPERATOR_PATH.format("BeamHook")) + @mock.patch(BEAM_OPERATOR_PATH.format("DataflowHook")) + @mock.patch(BEAM_OPERATOR_PATH.format("GCSHook")) def test_exec_dataflow_runner(self, gcs_hook, dataflow_hook_mock, beam_hook_mock, persist_link_mock): """Test DataflowHook is created and the right args are passed to start_python_dataflow. @@ -164,10 +194,10 @@ def test_exec_dataflow_runner(self, gcs_hook, dataflow_hook_mock, beam_hook_mock ) dataflow_hook_mock.return_value.provide_authorized_gcloud.assert_called_once_with() - @mock.patch("airflow.providers.apache.beam.operators.beam.DataflowJobLink.persist") - @mock.patch("airflow.providers.apache.beam.operators.beam.BeamHook") - @mock.patch("airflow.providers.apache.beam.operators.beam.GCSHook") - @mock.patch("airflow.providers.apache.beam.operators.beam.DataflowHook") + @mock.patch(BEAM_OPERATOR_PATH.format("DataflowJobLink.persist")) + @mock.patch(BEAM_OPERATOR_PATH.format("BeamHook")) + @mock.patch(BEAM_OPERATOR_PATH.format("GCSHook")) + @mock.patch(BEAM_OPERATOR_PATH.format("DataflowHook")) def test_on_kill_dataflow_runner(self, dataflow_hook_mock, _, __, ___): self.operator.runner = "DataflowRunner" dataflow_cancel_job = dataflow_hook_mock.return_value.cancel_job @@ -178,9 +208,9 @@ def test_on_kill_dataflow_runner(self, dataflow_hook_mock, _, __, ___): job_id=JOB_ID, project_id=self.operator.dataflow_config.project_id ) - @mock.patch("airflow.providers.apache.beam.operators.beam.BeamHook") - @mock.patch("airflow.providers.apache.beam.operators.beam.DataflowHook") - @mock.patch("airflow.providers.apache.beam.operators.beam.GCSHook") + @mock.patch(BEAM_OPERATOR_PATH.format("BeamHook")) + @mock.patch(BEAM_OPERATOR_PATH.format("DataflowHook")) + @mock.patch(BEAM_OPERATOR_PATH.format("GCSHook")) def test_on_kill_direct_runner(self, _, dataflow_mock, __): dataflow_cancel_job = dataflow_mock.return_value.cancel_job self.operator.execute(None) @@ -207,8 +237,8 @@ def test_init(self): assert self.operator.jar == JAR_FILE assert self.operator.pipeline_options == ADDITIONAL_OPTIONS - @mock.patch("airflow.providers.apache.beam.operators.beam.BeamHook") - @mock.patch("airflow.providers.apache.beam.operators.beam.GCSHook") + @mock.patch(BEAM_OPERATOR_PATH.format("BeamHook")) + @mock.patch(BEAM_OPERATOR_PATH.format("GCSHook")) def test_exec_direct_runner(self, gcs_hook, beam_hook_mock): """Test BeamHook is created and the right args are passed to start_java_workflow. 
@@ -226,10 +256,10 @@ def test_exec_direct_runner(self, gcs_hook, beam_hook_mock): process_line_callback=None, ) - @mock.patch("airflow.providers.apache.beam.operators.beam.DataflowJobLink.persist") - @mock.patch("airflow.providers.apache.beam.operators.beam.BeamHook") - @mock.patch("airflow.providers.apache.beam.operators.beam.DataflowHook") - @mock.patch("airflow.providers.apache.beam.operators.beam.GCSHook") + @mock.patch(BEAM_OPERATOR_PATH.format("DataflowJobLink.persist")) + @mock.patch(BEAM_OPERATOR_PATH.format("BeamHook")) + @mock.patch(BEAM_OPERATOR_PATH.format("DataflowHook")) + @mock.patch(BEAM_OPERATOR_PATH.format("GCSHook")) def test_exec_dataflow_runner(self, gcs_hook, dataflow_hook_mock, beam_hook_mock, persist_link_mock): """Test DataflowHook is created and the right args are passed to start_java_dataflow. @@ -274,10 +304,10 @@ def test_exec_dataflow_runner(self, gcs_hook, dataflow_hook_mock, beam_hook_mock process_line_callback=mock.ANY, ) - @mock.patch("airflow.providers.apache.beam.operators.beam.DataflowJobLink.persist") - @mock.patch("airflow.providers.apache.beam.operators.beam.BeamHook") - @mock.patch("airflow.providers.apache.beam.operators.beam.GCSHook") - @mock.patch("airflow.providers.apache.beam.operators.beam.DataflowHook") + @mock.patch(BEAM_OPERATOR_PATH.format("DataflowJobLink.persist")) + @mock.patch(BEAM_OPERATOR_PATH.format("BeamHook")) + @mock.patch(BEAM_OPERATOR_PATH.format("GCSHook")) + @mock.patch(BEAM_OPERATOR_PATH.format("DataflowHook")) def test_on_kill_dataflow_runner(self, dataflow_hook_mock, _, __, ___): self.operator.runner = "DataflowRunner" dataflow_hook_mock.return_value.is_job_dataflow_running.return_value = False @@ -289,9 +319,9 @@ def test_on_kill_dataflow_runner(self, dataflow_hook_mock, _, __, ___): job_id=JOB_ID, project_id=self.operator.dataflow_config.project_id ) - @mock.patch("airflow.providers.apache.beam.operators.beam.BeamHook") - @mock.patch("airflow.providers.apache.beam.operators.beam.DataflowHook") - @mock.patch("airflow.providers.apache.beam.operators.beam.GCSHook") + @mock.patch(BEAM_OPERATOR_PATH.format("BeamHook")) + @mock.patch(BEAM_OPERATOR_PATH.format("DataflowHook")) + @mock.patch(BEAM_OPERATOR_PATH.format("GCSHook")) def test_on_kill_direct_runner(self, _, dataflow_mock, __): dataflow_cancel_job = dataflow_mock.return_value.cancel_job self.operator.execute(None) @@ -386,8 +416,8 @@ def test_init_with_both_go_file_and_launcher_binary_raises(self): "tempfile.TemporaryDirectory", return_value=MagicMock(__enter__=MagicMock(return_value="/tmp/apache-beam-go")), ) - @mock.patch("airflow.providers.apache.beam.operators.beam.BeamHook") - @mock.patch("airflow.providers.apache.beam.operators.beam.GCSHook") + @mock.patch(BEAM_OPERATOR_PATH.format("BeamHook")) + @mock.patch(BEAM_OPERATOR_PATH.format("GCSHook")) def test_exec_direct_runner_with_gcs_go_file(self, gcs_hook, beam_hook_mock, _): """Test BeamHook is created and the right args are passed to start_go_workflow. 
@@ -413,8 +443,8 @@ def test_exec_direct_runner_with_gcs_go_file(self, gcs_hook, beam_hook_mock, _): should_init_module=True, ) - @mock.patch("airflow.providers.apache.beam.operators.beam.GCSHook") - @mock.patch("airflow.providers.apache.beam.operators.beam.BeamHook") + @mock.patch(BEAM_OPERATOR_PATH.format("GCSHook")) + @mock.patch(BEAM_OPERATOR_PATH.format("BeamHook")) @mock.patch("tempfile.TemporaryDirectory") def test_exec_direct_runner_with_gcs_launcher_binary( self, mock_tmp_dir, mock_beam_hook, mock_gcs_hook, tmp_path @@ -468,7 +498,7 @@ def gcs_download_side_effect(bucket_name: str, object_name: str, filename: str) process_line_callback=None, ) - @mock.patch("airflow.providers.apache.beam.operators.beam.BeamHook") + @mock.patch(BEAM_OPERATOR_PATH.format("BeamHook")) @mock.patch("airflow.providers.google.go_module_utils.init_module") def test_exec_direct_runner_with_local_go_file(self, init_module, beam_hook_mock): """ @@ -490,7 +520,7 @@ def test_exec_direct_runner_with_local_go_file(self, init_module, beam_hook_mock should_init_module=False, ) - @mock.patch("airflow.providers.apache.beam.operators.beam.BeamHook") + @mock.patch(BEAM_OPERATOR_PATH.format("BeamHook")) def test_exec_direct_runner_with_local_launcher_binary(self, mock_beam_hook): """ Test start_go_pipeline_with_binary is called with a local launcher binary. @@ -513,14 +543,14 @@ def test_exec_direct_runner_with_local_launcher_binary(self, mock_beam_hook): process_line_callback=None, ) - @mock.patch("airflow.providers.apache.beam.operators.beam.DataflowJobLink.persist") + @mock.patch(BEAM_OPERATOR_PATH.format("DataflowJobLink.persist")) @mock.patch( "tempfile.TemporaryDirectory", return_value=MagicMock(__enter__=MagicMock(return_value="/tmp/apache-beam-go")), ) - @mock.patch("airflow.providers.apache.beam.operators.beam.BeamHook") - @mock.patch("airflow.providers.apache.beam.operators.beam.DataflowHook") - @mock.patch("airflow.providers.apache.beam.operators.beam.GCSHook") + @mock.patch(BEAM_OPERATOR_PATH.format("BeamHook")) + @mock.patch(BEAM_OPERATOR_PATH.format("DataflowHook")) + @mock.patch(BEAM_OPERATOR_PATH.format("GCSHook")) def test_exec_dataflow_runner_with_go_file( self, gcs_hook, dataflow_hook_mock, beam_hook_mock, _, persist_link_mock ): @@ -575,11 +605,11 @@ def test_exec_dataflow_runner_with_go_file( ) dataflow_hook_mock.return_value.provide_authorized_gcloud.assert_called_once_with() - @mock.patch("airflow.providers.apache.beam.operators.beam.DataflowJobLink.persist") - @mock.patch("airflow.providers.apache.beam.operators.beam.DataflowHook") - @mock.patch("airflow.providers.apache.beam.operators.beam.GCSHook") - @mock.patch("airflow.providers.apache.beam.operators.beam.BeamHook") - @mock.patch("tempfile.TemporaryDirectory") + @mock.patch(BEAM_OPERATOR_PATH.format("DataflowJobLink.persist")) + @mock.patch(BEAM_OPERATOR_PATH.format("DataflowHook")) + @mock.patch(BEAM_OPERATOR_PATH.format("GCSHook")) + @mock.patch(BEAM_OPERATOR_PATH.format("BeamHook")) + @mock.patch(BEAM_OPERATOR_PATH.format("tempfile.TemporaryDirectory")) def test_exec_dataflow_runner_with_launcher_binary_and_worker_binary( self, mock_tmp_dir, mock_beam_hook, mock_gcs_hook, mock_dataflow_hook, mock_persist_link, tmp_path ): @@ -672,10 +702,10 @@ def gcs_download_side_effect(bucket_name: str, object_name: str, filename: str) project_id=dataflow_config.project_id, ) - @mock.patch("airflow.providers.apache.beam.operators.beam.DataflowJobLink.persist") - @mock.patch("airflow.providers.apache.beam.operators.beam.BeamHook") - 
@mock.patch("airflow.providers.apache.beam.operators.beam.GCSHook") - @mock.patch("airflow.providers.apache.beam.operators.beam.DataflowHook") + @mock.patch(BEAM_OPERATOR_PATH.format("DataflowJobLink.persist")) + @mock.patch(BEAM_OPERATOR_PATH.format("BeamHook")) + @mock.patch(BEAM_OPERATOR_PATH.format("GCSHook")) + @mock.patch(BEAM_OPERATOR_PATH.format("DataflowHook")) def test_on_kill_dataflow_runner(self, dataflow_hook_mock, _, __, ___): self.operator.runner = "DataflowRunner" dataflow_cancel_job = dataflow_hook_mock.return_value.cancel_job @@ -686,9 +716,9 @@ def test_on_kill_dataflow_runner(self, dataflow_hook_mock, _, __, ___): job_id=JOB_ID, project_id=self.operator.dataflow_config.project_id ) - @mock.patch("airflow.providers.apache.beam.operators.beam.BeamHook") - @mock.patch("airflow.providers.apache.beam.operators.beam.DataflowHook") - @mock.patch("airflow.providers.apache.beam.operators.beam.GCSHook") + @mock.patch(BEAM_OPERATOR_PATH.format("BeamHook")) + @mock.patch(BEAM_OPERATOR_PATH.format("DataflowHook")) + @mock.patch(BEAM_OPERATOR_PATH.format("GCSHook")) def test_on_kill_direct_runner(self, _, dataflow_mock, __): dataflow_cancel_job = dataflow_mock.return_value.cancel_job self.operator.execute(None) @@ -717,59 +747,161 @@ def test_init(self): assert self.operator.default_pipeline_options == DEFAULT_OPTIONS assert self.operator.pipeline_options == EXPECTED_ADDITIONAL_OPTIONS - @mock.patch("airflow.providers.apache.beam.operators.beam.BeamHook") - @mock.patch("airflow.providers.apache.beam.operators.beam.GCSHook") + @mock.patch(BEAM_OPERATOR_PATH.format("BeamHook")) + @mock.patch(BEAM_OPERATOR_PATH.format("GCSHook")) def test_async_execute_should_execute_successfully(self, gcs_hook, beam_hook_mock): """ - Asserts that a task is deferred and the BeamPipelineTrigger will be fired + Asserts that a task is deferred and the BeamPythonPipelineTrigger will be fired when the BeamRunPythonPipelineOperator is executed in deferrable mode when deferrable=True. """ with pytest.raises(TaskDeferred) as exc: self.operator.execute(context=mock.MagicMock()) - assert isinstance(exc.value.trigger, BeamPipelineTrigger), "Trigger is not a BeamPipelineTrigger" + assert isinstance( + exc.value.trigger, BeamPythonPipelineTrigger + ), "Trigger is not a BeamPythonPipelineTrigger" - def test_async_execute_should_throw_exception(self): - """Tests that an AirflowException is raised in case of error event""" + @mock.patch(BEAM_OPERATOR_PATH.format("BeamHook")) + @mock.patch(BEAM_OPERATOR_PATH.format("GCSHook")) + def test_async_execute_direct_runner(self, gcs_hook, beam_hook_mock): + """ + Test BeamHook is created and the right args are passed to + start_python_workflow when executing direct runner. 
+ """ + gcs_provide_file = gcs_hook.return_value.provide_file + with pytest.raises(TaskDeferred): + self.operator.execute(context=mock.MagicMock()) + beam_hook_mock.assert_called_once_with(runner=DEFAULT_RUNNER) + gcs_provide_file.assert_called_once_with(object_url=PY_FILE) - with pytest.raises(AirflowException): - self.operator.execute_complete( - context=mock.MagicMock(), event={"status": "error", "message": "test failure message"} - ) + @mock.patch(BEAM_OPERATOR_PATH.format("DataflowJobLink.persist")) + @mock.patch(BEAM_OPERATOR_PATH.format("BeamHook")) + @mock.patch(BEAM_OPERATOR_PATH.format("DataflowHook")) + @mock.patch(BEAM_OPERATOR_PATH.format("GCSHook")) + def test_exec_dataflow_runner(self, gcs_hook, dataflow_hook_mock, beam_hook_mock, persist_link_mock): + """ + Test DataflowHook is created and the right args are passed to + start_python_dataflow when executing Dataflow runner. + """ - def test_async_execute_logging_should_execute_successfully(self): - """Asserts that logging occurs as expected""" + dataflow_config = DataflowConfiguration(impersonation_chain=TEST_IMPERSONATION_ACCOUNT) + self.operator.runner = "DataflowRunner" + self.operator.dataflow_config = dataflow_config + gcs_provide_file = gcs_hook.return_value.provide_file + magic_mock = mock.MagicMock() + with pytest.raises(TaskDeferred): + self.operator.execute(context=magic_mock) - with mock.patch.object(self.operator.log, "info") as mock_log_info: - self.operator.execute_complete( - context=mock.MagicMock(), - event={"status": "success", "message": "Pipeline has finished SUCCESSFULLY"}, - ) - mock_log_info.assert_called_with( - "%s completed with response %s ", TASK_ID, "Pipeline has finished SUCCESSFULLY" + job_name = dataflow_hook_mock.build_dataflow_job_name.return_value + dataflow_hook_mock.assert_called_once_with( + gcp_conn_id=dataflow_config.gcp_conn_id, + poll_sleep=dataflow_config.poll_sleep, + impersonation_chain=dataflow_config.impersonation_chain, + drain_pipeline=dataflow_config.drain_pipeline, + cancel_timeout=dataflow_config.cancel_timeout, + wait_until_finished=dataflow_config.wait_until_finished, + ) + expected_options = { + "project": dataflow_hook_mock.return_value.project_id, + "job_name": job_name, + "staging_location": "gs://test/staging", + "output": "gs://test/output", + "labels": {"foo": "bar", "airflow-version": TEST_VERSION}, + "region": "us-central1", + "impersonate_service_account": TEST_IMPERSONATION_ACCOUNT, + } + gcs_provide_file.assert_called_once_with(object_url=PY_FILE) + persist_link_mock.assert_called_once_with( + self.operator, + magic_mock, + expected_options["project"], + expected_options["region"], + self.operator.dataflow_job_id, + ) + beam_hook_mock.return_value.start_python_pipeline.assert_not_called() + dataflow_hook_mock.return_value.provide_authorized_gcloud.assert_called_once_with() + + @mock.patch(BEAM_OPERATOR_PATH.format("DataflowJobLink.persist")) + @mock.patch(BEAM_OPERATOR_PATH.format("BeamHook")) + @mock.patch(BEAM_OPERATOR_PATH.format("GCSHook")) + @mock.patch(BEAM_OPERATOR_PATH.format("DataflowHook")) + def test_on_kill_dataflow_runner(self, dataflow_hook_mock, _, __, ___): + self.operator.runner = "DataflowRunner" + dataflow_cancel_job = dataflow_hook_mock.return_value.cancel_job + with pytest.raises(TaskDeferred): + self.operator.execute(context=mock.MagicMock()) + self.operator.dataflow_job_id = JOB_ID + self.operator.on_kill() + dataflow_cancel_job.assert_called_once_with( + job_id=JOB_ID, project_id=self.operator.dataflow_config.project_id ) - 
@mock.patch("airflow.providers.apache.beam.operators.beam.BeamHook") - @mock.patch("airflow.providers.apache.beam.operators.beam.GCSHook") + @mock.patch(BEAM_OPERATOR_PATH.format("BeamHook")) + @mock.patch(BEAM_OPERATOR_PATH.format("DataflowHook")) + @mock.patch(BEAM_OPERATOR_PATH.format("GCSHook")) + def test_on_kill_direct_runner(self, _, dataflow_mock, __): + dataflow_cancel_job = dataflow_mock.return_value.cancel_job + with pytest.raises(TaskDeferred): + self.operator.execute(mock.MagicMock()) + self.operator.on_kill() + dataflow_cancel_job.assert_not_called() + + +class TestBeamRunJavaPipelineOperatorAsync: + def setup_method(self): + self.operator = BeamRunJavaPipelineOperator( + task_id=TASK_ID, + jar=JAR_FILE, + job_class=JOB_CLASS, + default_pipeline_options=DEFAULT_OPTIONS, + pipeline_options=ADDITIONAL_OPTIONS, + deferrable=True, + ) + + def test_init(self): + """Test BeamRunJavaPipelineOperator instance is properly initialized.""" + assert self.operator.task_id == TASK_ID + assert self.operator.jar == JAR_FILE + assert self.operator.runner == DEFAULT_RUNNER + assert self.operator.job_class == JOB_CLASS + assert self.operator.default_pipeline_options == DEFAULT_OPTIONS + assert self.operator.pipeline_options == ADDITIONAL_OPTIONS + + @mock.patch(BEAM_OPERATOR_PATH.format("BeamHook")) + @mock.patch(BEAM_OPERATOR_PATH.format("GCSHook")) + def test_async_execute_should_execute_successfully(self, gcs_hook, beam_hook_mock): + """ + Asserts that a task is deferred and the BeamJavaPipelineTrigger will be fired + when the BeamRunPythonPipelineOperator is executed in deferrable mode when deferrable=True. + """ + with pytest.raises(TaskDeferred) as exc: + self.operator.execute(context=mock.MagicMock()) + + assert isinstance( + exc.value.trigger, BeamJavaPipelineTrigger + ), "Trigger is not a BeamPJavaPipelineTrigger" + + @mock.patch(BEAM_OPERATOR_PATH.format("BeamHook")) + @mock.patch(BEAM_OPERATOR_PATH.format("GCSHook")) def test_async_execute_direct_runner(self, gcs_hook, beam_hook_mock): """ Test BeamHook is created and the right args are passed to - start_python_workflow when executing direct runner. + start_java_pipeline when executing direct runner. """ gcs_provide_file = gcs_hook.return_value.provide_file with pytest.raises(TaskDeferred): self.operator.execute(context=mock.MagicMock()) beam_hook_mock.assert_called_once_with(runner=DEFAULT_RUNNER) - gcs_provide_file.assert_called_once_with(object_url=PY_FILE) + gcs_provide_file.assert_called_once_with(object_url=JAR_FILE) - @mock.patch("airflow.providers.apache.beam.operators.beam.DataflowJobLink.persist") - @mock.patch("airflow.providers.apache.beam.operators.beam.BeamHook") - @mock.patch("airflow.providers.apache.beam.operators.beam.DataflowHook") - @mock.patch("airflow.providers.apache.beam.operators.beam.GCSHook") + @mock.patch(BEAM_OPERATOR_PATH.format("DataflowJobLink.persist")) + @mock.patch(BEAM_OPERATOR_PATH.format("BeamHook")) + @mock.patch(BEAM_OPERATOR_PATH.format("DataflowHook")) + @mock.patch(BEAM_OPERATOR_PATH.format("GCSHook")) def test_exec_dataflow_runner(self, gcs_hook, dataflow_hook_mock, beam_hook_mock, persist_link_mock): """ Test DataflowHook is created and the right args are passed to - start_python_dataflow when executing Dataflow runner. + start_java_pipeline when executing Dataflow runner. 
""" dataflow_config = DataflowConfiguration(impersonation_chain=TEST_IMPERSONATION_ACCOUNT) @@ -798,7 +930,7 @@ def test_exec_dataflow_runner(self, gcs_hook, dataflow_hook_mock, beam_hook_mock "region": "us-central1", "impersonate_service_account": TEST_IMPERSONATION_ACCOUNT, } - gcs_provide_file.assert_called_once_with(object_url=PY_FILE) + gcs_provide_file.assert_called_once_with(object_url=JAR_FILE) persist_link_mock.assert_called_once_with( self.operator, magic_mock, @@ -809,10 +941,10 @@ def test_exec_dataflow_runner(self, gcs_hook, dataflow_hook_mock, beam_hook_mock beam_hook_mock.return_value.start_python_pipeline.assert_not_called() dataflow_hook_mock.return_value.provide_authorized_gcloud.assert_called_once_with() - @mock.patch("airflow.providers.apache.beam.operators.beam.DataflowJobLink.persist") - @mock.patch("airflow.providers.apache.beam.operators.beam.BeamHook") - @mock.patch("airflow.providers.apache.beam.operators.beam.GCSHook") - @mock.patch("airflow.providers.apache.beam.operators.beam.DataflowHook") + @mock.patch(BEAM_OPERATOR_PATH.format("DataflowJobLink.persist")) + @mock.patch(BEAM_OPERATOR_PATH.format("BeamHook")) + @mock.patch(BEAM_OPERATOR_PATH.format("GCSHook")) + @mock.patch(BEAM_OPERATOR_PATH.format("DataflowHook")) def test_on_kill_dataflow_runner(self, dataflow_hook_mock, _, __, ___): self.operator.runner = "DataflowRunner" dataflow_cancel_job = dataflow_hook_mock.return_value.cancel_job @@ -824,9 +956,9 @@ def test_on_kill_dataflow_runner(self, dataflow_hook_mock, _, __, ___): job_id=JOB_ID, project_id=self.operator.dataflow_config.project_id ) - @mock.patch("airflow.providers.apache.beam.operators.beam.BeamHook") - @mock.patch("airflow.providers.apache.beam.operators.beam.DataflowHook") - @mock.patch("airflow.providers.apache.beam.operators.beam.GCSHook") + @mock.patch(BEAM_OPERATOR_PATH.format("BeamHook")) + @mock.patch(BEAM_OPERATOR_PATH.format("DataflowHook")) + @mock.patch(BEAM_OPERATOR_PATH.format("GCSHook")) def test_on_kill_direct_runner(self, _, dataflow_mock, __): dataflow_cancel_job = dataflow_mock.return_value.cancel_job with pytest.raises(TaskDeferred): diff --git a/tests/providers/apache/beam/triggers/test_beam.py b/tests/providers/apache/beam/triggers/test_beam.py index 82e56ff3ec639..6bd1b4bc6647e 100644 --- a/tests/providers/apache/beam/triggers/test_beam.py +++ b/tests/providers/apache/beam/triggers/test_beam.py @@ -20,17 +20,21 @@ import pytest -from airflow.providers.apache.beam.triggers.beam import BeamPipelineTrigger +from airflow.providers.apache.beam.triggers.beam import BeamJavaPipelineTrigger, BeamPythonPipelineTrigger from airflow.triggers.base import TriggerEvent -HOOK_STATUS_STR = "airflow.providers.apache.beam.hooks.beam.BeamAsyncHook.start_python_pipeline_async" -CLASSPATH = "airflow.providers.apache.beam.triggers.beam.BeamPipelineTrigger" +HOOK_STATUS_STR_PYTHON = "airflow.providers.apache.beam.hooks.beam.BeamAsyncHook.start_python_pipeline_async" +HOOK_STATUS_STR_JAVA = "airflow.providers.apache.beam.hooks.beam.BeamAsyncHook.start_java_pipeline_async" +CLASSPATH_PYTHON = "airflow.providers.apache.beam.triggers.beam.BeamPythonPipelineTrigger" +CLASSPATH_JAVA = "airflow.providers.apache.beam.triggers.beam.BeamJavaPipelineTrigger" TASK_ID = "test_task" LOCATION = "test-location" INSTANCE_NAME = "airflow-test-instance" INSTANCE = {"type": "BASIC", "displayName": INSTANCE_NAME} PROJECT_ID = "test_project_id" +TEST_GCP_CONN_ID = "test_gcp_conn_id" +TEST_IMPERSONATION_CHAIN = "A,B,C" TEST_VARIABLES = {"output": 
"gs://bucket_test/output", "labels": {"airflow-version": "v2-7-0-dev0"}} TEST_PY_FILE = "apache_beam.examples.wordcount" TEST_PY_OPTIONS: list[str] = [] @@ -38,11 +42,17 @@ TEST_PY_REQUIREMENTS = ["apache-beam[gcp]==2.46.0"] TEST_PY_PACKAGES = False TEST_RUNNER = "DirectRunner" +TEST_JAR_FILE = "example.jar" +TEST_JOB_CLASS = "TestClass" +TEST_CHECK_IF_RUNNING = False +TEST_JOB_NAME = "test_job_name" +TEST_POLL_SLEEP = 10 +TEST_CANCEL_TIMEOUT = 300 @pytest.fixture -def trigger(): - return BeamPipelineTrigger( +def python_trigger(): + return BeamPythonPipelineTrigger( variables=TEST_VARIABLES, py_file=TEST_PY_FILE, py_options=TEST_PY_OPTIONS, @@ -53,14 +63,32 @@ def trigger(): ) -class TestBeamPipelineTrigger: - def test_beam_trigger_serialization_should_execute_successfully(self, trigger): +@pytest.fixture +def java_trigger(): + return BeamJavaPipelineTrigger( + variables=TEST_VARIABLES, + jar=TEST_JAR_FILE, + job_class=TEST_JOB_CLASS, + runner=TEST_RUNNER, + check_if_running=TEST_CHECK_IF_RUNNING, + project_id=PROJECT_ID, + location=LOCATION, + job_name=TEST_JOB_NAME, + gcp_conn_id=TEST_GCP_CONN_ID, + impersonation_chain=TEST_IMPERSONATION_CHAIN, + poll_sleep=TEST_POLL_SLEEP, + cancel_timeout=TEST_CANCEL_TIMEOUT, + ) + + +class TestBeamPythonPipelineTrigger: + def test_beam_trigger_serialization_should_execute_successfully(self, python_trigger): """ - Asserts that the BeamPipelineTrigger correctly serializes its arguments + Asserts that the BeamPythonPipelineTrigger correctly serializes its arguments and classpath. """ - classpath, kwargs = trigger.serialize() - assert classpath == CLASSPATH + classpath, kwargs = python_trigger.serialize() + assert classpath == CLASSPATH_PYTHON assert kwargs == { "variables": TEST_VARIABLES, "py_file": TEST_PY_FILE, @@ -72,36 +100,118 @@ def test_beam_trigger_serialization_should_execute_successfully(self, trigger): } @pytest.mark.asyncio - @mock.patch(HOOK_STATUS_STR) - async def test_beam_trigger_on_success_should_execute_successfully(self, mock_pipeline_status, trigger): + @mock.patch(HOOK_STATUS_STR_PYTHON) + async def test_beam_trigger_on_success_should_execute_successfully( + self, mock_pipeline_status, python_trigger + ): """ - Tests the BeamPipelineTrigger only fires once the job execution reaches a successful state. + Tests the BeamPythonPipelineTrigger only fires once the job execution reaches a successful state. """ mock_pipeline_status.return_value = 0 - generator = trigger.run() + generator = python_trigger.run() actual = await generator.asend(None) assert TriggerEvent({"status": "success", "message": "Pipeline has finished SUCCESSFULLY"}) == actual @pytest.mark.asyncio - @mock.patch(HOOK_STATUS_STR) - async def test_beam_trigger_error_should_execute_successfully(self, mock_pipeline_status, trigger): + @mock.patch(HOOK_STATUS_STR_PYTHON) + async def test_beam_trigger_error_should_execute_successfully(self, mock_pipeline_status, python_trigger): """ - Test that BeamPipelineTrigger fires the correct event in case of an error. + Test that BeamPythonPipelineTrigger fires the correct event in case of an error. 
""" mock_pipeline_status.return_value = 1 - generator = trigger.run() + generator = python_trigger.run() actual = await generator.asend(None) assert TriggerEvent({"status": "error", "message": "Operation failed"}) == actual @pytest.mark.asyncio - @mock.patch(HOOK_STATUS_STR) - async def test_beam_trigger_exception_should_execute_successfully(self, mock_pipeline_status, trigger): + @mock.patch(HOOK_STATUS_STR_PYTHON) + async def test_beam_trigger_exception_should_execute_successfully( + self, mock_pipeline_status, python_trigger + ): """ - Test that BeamPipelineTrigger fires the correct event in case of an error. + Test that BeamPythonPipelineTrigger fires the correct event in case of an error. """ mock_pipeline_status.side_effect = Exception("Test exception") - generator = trigger.run() + generator = python_trigger.run() + actual = await generator.asend(None) + assert TriggerEvent({"status": "error", "message": "Test exception"}) == actual + + +class TestBeamJavaPipelineTrigger: + def test_beam_trigger_serialization_should_execute_successfully(self, java_trigger): + """ + Asserts that the BeamJavaPipelineTrigger correctly serializes its arguments + and classpath. + """ + classpath, kwargs = java_trigger.serialize() + assert classpath == CLASSPATH_JAVA + assert kwargs == { + "variables": TEST_VARIABLES, + "jar": TEST_JAR_FILE, + "job_class": TEST_JOB_CLASS, + "runner": TEST_RUNNER, + "check_if_running": TEST_CHECK_IF_RUNNING, + "project_id": PROJECT_ID, + "location": LOCATION, + "job_name": TEST_JOB_NAME, + "gcp_conn_id": TEST_GCP_CONN_ID, + "impersonation_chain": TEST_IMPERSONATION_CHAIN, + "poll_sleep": TEST_POLL_SLEEP, + "cancel_timeout": TEST_CANCEL_TIMEOUT, + } + + @pytest.mark.asyncio + @mock.patch(HOOK_STATUS_STR_JAVA) + async def test_beam_trigger_on_success_should_execute_successfully( + self, mock_pipeline_status, java_trigger + ): + """ + Tests the BeamJavaPipelineTrigger only fires once the job execution reaches a successful state. + """ + mock_pipeline_status.return_value = 0 + generator = java_trigger.run() + actual = await generator.asend(None) + assert TriggerEvent({"status": "success", "message": "Pipeline has finished SUCCESSFULLY"}) == actual + + @pytest.mark.asyncio + @mock.patch(HOOK_STATUS_STR_JAVA) + async def test_beam_trigger_error_should_execute_successfully(self, mock_pipeline_status, java_trigger): + """ + Test that BeamJavaPipelineTrigger fires the correct event in case of an error. + """ + mock_pipeline_status.return_value = 1 + + generator = java_trigger.run() + actual = await generator.asend(None) + assert TriggerEvent({"status": "error", "message": "Operation failed"}) == actual + + @pytest.mark.asyncio + @mock.patch(HOOK_STATUS_STR_JAVA) + async def test_beam_trigger_exception_should_execute_successfully( + self, mock_pipeline_status, java_trigger + ): + """ + Test that BeamJavaPipelineTrigger fires the correct event in case of an error. + """ + mock_pipeline_status.side_effect = Exception("Test exception") + + generator = java_trigger.run() + actual = await generator.asend(None) + assert TriggerEvent({"status": "error", "message": "Test exception"}) == actual + + @pytest.mark.asyncio + @mock.patch("airflow.providers.google.cloud.hooks.dataflow.AsyncDataflowHook.list_jobs") + async def test_beam_trigger_exception_list_jobs_should_execute_successfully( + self, mock_list_jobs, java_trigger + ): + """ + Test that BeamJavaPipelineTrigger fires the correct event in case of an error. 
+ """ + mock_list_jobs.side_effect = Exception("Test exception") + + java_trigger.check_if_running = True + generator = java_trigger.run() actual = await generator.asend(None) assert TriggerEvent({"status": "error", "message": "Test exception"}) == actual diff --git a/tests/providers/google/cloud/hooks/test_dataflow.py b/tests/providers/google/cloud/hooks/test_dataflow.py index 22d4051c43f93..e26713502abc7 100644 --- a/tests/providers/google/cloud/hooks/test_dataflow.py +++ b/tests/providers/google/cloud/hooks/test_dataflow.py @@ -27,7 +27,7 @@ from uuid import UUID import pytest -from google.cloud.dataflow_v1beta3 import GetJobRequest, JobView +from google.cloud.dataflow_v1beta3 import GetJobRequest, JobView, ListJobsRequest from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning from airflow.providers.apache.beam.hooks.beam import BeamHook, run_beam_command @@ -89,6 +89,7 @@ DATAFLOW_STRING = "airflow.providers.google.cloud.hooks.dataflow.{}" TEST_PROJECT = "test-project" TEST_JOB_ID = "test-job-id" +TEST_JOBS_FILTER = ListJobsRequest.Filter.ACTIVE TEST_LOCATION = "custom-location" DEFAULT_PY_INTERPRETER = "python3" TEST_FLEX_PARAMETERS = { @@ -1949,7 +1950,7 @@ def hook(self): ) @pytest.mark.asyncio - @mock.patch("airflow.providers.google.cloud.hooks.dataflow.AsyncDataflowHook.initialize_client") + @mock.patch(DATAFLOW_STRING.format("AsyncDataflowHook.initialize_client")) async def test_get_job(self, initialize_client_mock, hook, make_mock_awaitable): client = initialize_client_mock.return_value make_mock_awaitable(client.get_job, None) @@ -1972,3 +1973,27 @@ async def test_get_job(self, initialize_client_mock, hook, make_mock_awaitable): client.get_job.assert_called_once_with( request=request, ) + + @pytest.mark.asyncio + @mock.patch(DATAFLOW_STRING.format("AsyncDataflowHook.initialize_client")) + async def test_list_jobs(self, initialize_client_mock, hook, make_mock_awaitable): + client = initialize_client_mock.return_value + make_mock_awaitable(client.get_job, None) + + await hook.list_jobs( + project_id=TEST_PROJECT_ID, + location=TEST_LOCATION, + jobs_filter=TEST_JOBS_FILTER, + ) + + request = ListJobsRequest( + { + "project_id": TEST_PROJECT_ID, + "location": TEST_LOCATION, + "filter": TEST_JOBS_FILTER, + "page_size": None, + "page_token": None, + } + ) + initialize_client_mock.assert_called_once() + client.list_jobs.assert_called_once_with(request=request) diff --git a/tests/system/providers/google/cloud/dataflow/example_dataflow_native_java.py b/tests/system/providers/google/cloud/dataflow/example_dataflow_native_java.py index 53b33b89e9571..12047bae5b404 100644 --- a/tests/system/providers/google/cloud/dataflow/example_dataflow_native_java.py +++ b/tests/system/providers/google/cloud/dataflow/example_dataflow_native_java.py @@ -48,12 +48,12 @@ DAG_ID = "dataflow_native_java" BUCKET_NAME = f"bucket_{DAG_ID}_{ENV_ID}" -PUBLIC_BUCKET = "system-tests-resources" +PUBLIC_BUCKET = "airflow-system-tests-resources" JAR_FILE_NAME = "word-count-beam-bundled-0.1.jar" -REMOTE_JAR_FILE_PATH = f"{DAG_ID}/{JAR_FILE_NAME}" +REMOTE_JAR_FILE_PATH = f"dataflow/java/{JAR_FILE_NAME}" GCS_OUTPUT = f"gs://{BUCKET_NAME}" -GCS_JAR = f"gs://{PUBLIC_BUCKET}/{REMOTE_JAR_FILE_PATH}" +GCS_JAR = f"gs://{PUBLIC_BUCKET}/dataflow/java/{JAR_FILE_NAME}" LOCATION = "europe-west3" with DAG( @@ -105,11 +105,38 @@ ) # [END howto_operator_start_java_job_jar_on_gcs] + # [START howto_operator_start_java_job_jar_on_gcs_deferrable] + start_java_deferrable = BeamRunJavaPipelineOperator( + 
runner=BeamRunnerType.DataflowRunner, + task_id="start-java-job-deferrable", + jar=GCS_JAR, + pipeline_options={ + "output": GCS_OUTPUT, + }, + job_class="org.apache.beam.examples.WordCount", + dataflow_config={ + "check_if_running": CheckJobRunning.WaitForRun, + "location": LOCATION, + "poll_sleep": 10, + "append_job_name": False, + }, + deferrable=True, + ) + # [END howto_operator_start_java_job_jar_on_gcs_deferrable] + delete_bucket = GCSDeleteBucketOperator( task_id="delete_bucket", bucket_name=BUCKET_NAME, trigger_rule=TriggerRule.ALL_DONE ) - create_bucket >> download_file >> [start_java_job_local, start_java_job] >> delete_bucket + ( + # TEST SETUP + create_bucket + >> download_file + # TEST BODY + >> [start_java_job_local, start_java_job, start_java_deferrable] + # TEST TEARDOWN + >> delete_bucket + ) from tests.system.utils.watcher import watcher
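
The command assembly in the new `start_java_pipeline_async` reduces to the following standalone sketch (`build_java_command` is an illustrative name, not part of this patch; the labels dict is serialized to compact JSON so it survives as a single CLI argument):

    from __future__ import annotations

    import json

    def build_java_command(variables: dict, jar: str, job_class: str | None = None) -> list[str]:
        # Labels must be a compact JSON string so they pass through as one --labels=... argument.
        if "labels" in variables:
            variables["labels"] = json.dumps(variables["labels"], separators=(",", ":"))
        # With a job class the jar goes on the classpath; otherwise it is run as an executable jar.
        return ["java", "-cp", jar, job_class] if job_class else ["java", "-jar", jar]

    print(build_java_command({"labels": {"foo": "bar"}}, "pipeline.jar", "org.example.WordCount"))
    # -> ['java', '-cp', 'pipeline.jar', 'org.example.WordCount']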
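The `execute_async` path in `BeamRunJavaPipelineOperator` downloads a GCS jar without blocking the event loop by entering the synchronous `provide_file()` context manager inside the default thread-pool executor. A minimal, runnable sketch of that pattern, with a temp file standing in for the GCS hook:

    import asyncio
    import contextlib
    import tempfile

    @contextlib.contextmanager
    def provide_file():
        # Stand-in for the blocking gcs_hook.provide_file() context manager.
        with tempfile.NamedTemporaryFile(suffix=".jar") as tmp:
            yield tmp

    async def main():
        loop = asyncio.get_event_loop()
        # Enter the synchronous context manager in the default thread-pool executor
        # (None) so the running event loop is not blocked while the file is fetched.
        # The throwaway ExitStack mirrors the patch: cleanup is deferred, not immediate.
        tmp_file = await loop.run_in_executor(None, contextlib.ExitStack().enter_context, provide_file())
        print(tmp_file.name)

    asyncio.run(main())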
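In deferrable mode the operator hands off to `BeamJavaPipelineTrigger`, whose `run()` async generator yields a single success or error `TriggerEvent`. A sketch of driving it directly, mocking the hook call the same way the unit tests above do (the jar path and job class are placeholders):

    import asyncio
    from unittest import mock

    from airflow.providers.apache.beam.triggers.beam import BeamJavaPipelineTrigger

    async def first_event():
        trigger = BeamJavaPipelineTrigger(
            variables={"output": "gs://test/output"},
            jar="/tmp/example.jar",  # placeholder path
            job_class="org.example.WordCount",  # placeholder class
            runner="DirectRunner",
        )
        # run() is an async generator; asend(None) advances it to its first TriggerEvent.
        return await trigger.run().asend(None)

    # Patch the hook call exactly as the unit tests do; return code 0 means success.
    with mock.patch(
        "airflow.providers.apache.beam.hooks.beam.BeamAsyncHook.start_java_pipeline_async",
        return_value=0,
    ):
        print(asyncio.run(first_event()))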
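When `check_if_running` is set, the trigger polls `AsyncDataflowHook.list_jobs` with the `ACTIVE` filter until no job with the given name remains. The same call can be exercised on its own; this sketch assumes valid Google Cloud credentials, and the project and region values are placeholders:

    import asyncio

    from google.cloud.dataflow_v1beta3 import ListJobsRequest

    from airflow.providers.google.cloud.hooks.dataflow import AsyncDataflowHook

    async def active_job_names(project_id: str, location: str) -> list[str]:
        hook = AsyncDataflowHook(gcp_conn_id="google_cloud_default")
        pager = await hook.list_jobs(
            project_id=project_id,
            location=location,
            jobs_filter=ListJobsRequest.Filter.ACTIVE,
        )
        # ListJobsAsyncPager is an async iterable of Job messages.
        return [job.name async for job in pager]

    # asyncio.run(active_job_names("my-project", "us-central1"))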