19 changes: 19 additions & 0 deletions airflow/providers/apache/beam/hooks/beam.py
@@ -508,6 +508,25 @@ async def start_python_pipeline_async(
)
return return_code

async def start_java_pipeline_async(self, variables: dict, jar: str, job_class: str | None = None):
"""
Start Apache Beam Java pipeline.

:param variables: Variables passed to the job.
:param jar: Name of the jar for the pipeline.
:param job_class: Name of the java class for the pipeline.
:return: Beam command execution return code.
"""
if "labels" in variables:
variables["labels"] = json.dumps(variables["labels"], separators=(",", ":"))

command_prefix = ["java", "-cp", jar, job_class] if job_class else ["java", "-jar", jar]
return_code = await self.start_pipeline_async(
variables=variables,
command_prefix=command_prefix,
)
return return_code

async def start_pipeline_async(
self,
variables: dict,
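The new `start_java_pipeline_async` hook method mirrors the existing Python variant: it serializes any `labels` option to compact JSON and delegates to `start_pipeline_async` with a `java -cp`/`java -jar` command prefix. A minimal usage sketch follows, assuming the async hook is driven from an event loop; the jar path, job class, and pipeline options are illustrative placeholders.

```python
import asyncio

from airflow.providers.apache.beam.hooks.beam import BeamAsyncHook, BeamRunnerType


async def run_wordcount() -> None:
    # Hypothetical invocation of the new async Java entry point; all values are placeholders.
    hook = BeamAsyncHook(runner=BeamRunnerType.DirectRunner)
    return_code = await hook.start_java_pipeline_async(
        variables={"labels": {"team": "data"}, "output": "/tmp/wordcount-out"},
        jar="/tmp/wordcount.jar",
        job_class="org.example.WordCount",  # omit to fall back to `java -jar`
    )
    if return_code:
        raise RuntimeError(f"Beam pipeline exited with return code {return_code}")


asyncio.run(run_wordcount())
```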
163 changes: 106 additions & 57 deletions airflow/providers/apache/beam/operators/beam.py
@@ -34,7 +34,7 @@
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.apache.beam.hooks.beam import BeamHook, BeamRunnerType
from airflow.providers.apache.beam.triggers.beam import BeamPipelineTrigger
from airflow.providers.apache.beam.triggers.beam import BeamJavaPipelineTrigger, BeamPythonPipelineTrigger
from airflow.providers.google.cloud.hooks.dataflow import (
DataflowHook,
process_line_and_extract_dataflow_job_id_callback,
@@ -239,6 +239,22 @@ def _init_pipeline_options(
check_job_status_callback,
)

def execute_complete(self, context: Context, event: dict[str, Any]):
"""
Execute when the trigger fires - returns immediately.

Relies on trigger to throw an exception, otherwise it assumes execution was
successful.
"""
if event["status"] == "error":
raise AirflowException(event["message"])
self.log.info(
"%s completed with response %s ",
self.task_id,
event["message"],
)
return {"dataflow_job_id": self.dataflow_job_id}
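Moving `execute_complete` onto the base operator lets both the Python and the new Java deferrable paths share one callback. The triggers are expected to send a simple status/message payload; anything other than `status == "error"` is treated as success, and the Dataflow job id is returned as the task result. The field values below are illustrative only.

```python
# Illustrative payloads only: the callback raises on "error" and logs otherwise.
success_event = {"status": "success", "message": "Pipeline has finished SUCCESSFULLY"}
error_event = {"status": "error", "message": "Operation failed"}
```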


class BeamRunPythonPipelineOperator(BeamBasePipelineOperator):
"""
Expand Down Expand Up @@ -323,7 +339,7 @@ def __init__(
self.deferrable = deferrable

def execute(self, context: Context):
"""Execute the Apache Beam Pipeline."""
"""Execute the Apache Beam Python Pipeline."""
(
self.is_dataflow,
self.dataflow_job_name,
@@ -408,7 +424,7 @@ async def execute_async(self, context: Context):
)
with self.dataflow_hook.provide_authorized_gcloud():
self.defer(
trigger=BeamPipelineTrigger(
trigger=BeamPythonPipelineTrigger(
variables=self.snake_case_pipeline_options,
py_file=self.py_file,
py_options=self.py_options,
@@ -421,7 +437,7 @@ async def execute_async(self, context: Context):
)
else:
self.defer(
trigger=BeamPipelineTrigger(
trigger=BeamPythonPipelineTrigger(
variables=self.snake_case_pipeline_options,
py_file=self.py_file,
py_options=self.py_options,
@@ -433,22 +449,6 @@ async def execute_async(self, context: Context):
method_name="execute_complete",
)

def execute_complete(self, context: Context, event: dict[str, Any]):
"""
Execute when the trigger fires - returns immediately.

Relies on trigger to throw an exception, otherwise it assumes execution was
successful.
"""
if event["status"] == "error":
raise AirflowException(event["message"])
self.log.info(
"%s completed with response %s ",
self.task_id,
event["message"],
)
return {"dataflow_job_id": self.dataflow_job_id}

def on_kill(self) -> None:
if self.dataflow_hook and self.dataflow_job_id:
self.log.info("Dataflow job with id: `%s` was requested to be cancelled.", self.dataflow_job_id)
@@ -509,6 +509,7 @@ def __init__(
pipeline_options: dict | None = None,
gcp_conn_id: str = "google_cloud_default",
dataflow_config: DataflowConfiguration | dict | None = None,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs,
) -> None:
super().__init__(
@@ -521,61 +522,55 @@ def __init__(
)
self.jar = jar
self.job_class = job_class
self.deferrable = deferrable
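With the `deferrable` flag now wired through `__init__`, the Java operator can hand its wait off to the triggerer the same way the Python operator already does. A minimal DAG sketch follows, assuming a jar already staged in GCS; the bucket, job class, and Dataflow settings are placeholders.

```python
import pendulum

from airflow import DAG
from airflow.providers.apache.beam.operators.beam import BeamRunJavaPipelineOperator

with DAG(
    dag_id="example_beam_java_deferrable",
    start_date=pendulum.datetime(2024, 1, 1, tz="UTC"),
    schedule=None,
    catchup=False,
) as dag:
    run_java = BeamRunJavaPipelineOperator(
        task_id="run_java_pipeline",
        runner="DataflowRunner",
        jar="gs://my-bucket/pipelines/wordcount.jar",
        job_class="org.example.WordCount",
        pipeline_options={"tempLocation": "gs://my-bucket/tmp/"},
        dataflow_config={"location": "us-central1", "job_name": "wordcount"},
        deferrable=True,  # defer the wait to the triggerer instead of holding a worker slot
    )
```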

def execute(self, context: Context):
"""Execute the Apache Beam Pipeline."""
"""Execute the Apache Beam Python Pipeline."""
(
is_dataflow,
dataflow_job_name,
pipeline_options,
process_line_callback,
self.is_dataflow,
self.dataflow_job_name,
self.pipeline_options,
self.process_line_callback,
_,
) = self._init_pipeline_options()

if not self.beam_hook:
raise AirflowException("Beam hook is not defined.")
if self.deferrable:
asyncio.run(self.execute_async(context))
else:
return self.execute_sync(context)

def execute_sync(self, context: Context):
"""Execute the Apache Beam Pipeline."""
with ExitStack() as exit_stack:
if self.jar.lower().startswith("gs://"):
gcs_hook = GCSHook(self.gcp_conn_id)
tmp_gcs_file = exit_stack.enter_context(gcs_hook.provide_file(object_url=self.jar))
self.jar = tmp_gcs_file.name

if is_dataflow and self.dataflow_hook:
is_running = False
if self.dataflow_config.check_if_running != CheckJobRunning.IgnoreJob:
is_running = (
# The reason for disable=no-value-for-parameter is that the project_id parameter is
# required but is not passed here; moreover, it cannot be passed here.
# This method is wrapped by the @_fallback_to_project_id_from_variables decorator, which
# falls back to the project_id value from variables and raises an error if project_id is
# defined both in variables and as a parameter (here it is already defined in variables)
self.dataflow_hook.is_job_dataflow_running(
name=self.dataflow_config.job_name,
variables=pipeline_options,
)
if self.is_dataflow and self.dataflow_hook:
is_running = self.dataflow_config.check_if_running == CheckJobRunning.WaitForRun
while is_running and self.dataflow_config.check_if_running == CheckJobRunning.WaitForRun:
# The reason for disable=no-value-for-parameter is that the project_id parameter is
# required but is not passed here; moreover, it cannot be passed here.
# This method is wrapped by the @_fallback_to_project_id_from_variables decorator, which
# falls back to the project_id value from variables and raises an error if project_id is
# defined both in variables and as a parameter (here it is already defined in variables)
is_running = self.dataflow_hook.is_job_dataflow_running(
name=self.dataflow_config.job_name,
variables=self.pipeline_options,
)
while is_running and self.dataflow_config.check_if_running == CheckJobRunning.WaitForRun:
# The reason for disable=no-value-for-parameter is that the project_id parameter is
# required but is not passed here; moreover, it cannot be passed here.
# This method is wrapped by the @_fallback_to_project_id_from_variables decorator, which
# falls back to the project_id value from variables and raises an error if project_id is
# defined both in variables and as a parameter (here it is already defined in variables)

is_running = self.dataflow_hook.is_job_dataflow_running(
name=self.dataflow_config.job_name,
variables=pipeline_options,
)

if not is_running:
pipeline_options["jobName"] = dataflow_job_name
self.pipeline_options["jobName"] = self.dataflow_job_name
with self.dataflow_hook.provide_authorized_gcloud():
self.beam_hook.start_java_pipeline(
variables=pipeline_options,
variables=self.pipeline_options,
jar=self.jar,
job_class=self.job_class,
process_line_callback=process_line_callback,
process_line_callback=self.process_line_callback,
)
if dataflow_job_name and self.dataflow_config.location:
if self.dataflow_job_name and self.dataflow_config.location:
multiple_jobs = self.dataflow_config.multiple_jobs or False
DataflowJobLink.persist(
self,
@@ -585,7 +580,7 @@ def execute(self, context: Context):
self.dataflow_job_id,
)
self.dataflow_hook.wait_for_done(
job_name=dataflow_job_name,
job_name=self.dataflow_job_name,
location=self.dataflow_config.location,
job_id=self.dataflow_job_id,
multiple_jobs=multiple_jobs,
@@ -594,11 +589,65 @@ def execute(self, context: Context):
return {"dataflow_job_id": self.dataflow_job_id}
else:
self.beam_hook.start_java_pipeline(
variables=pipeline_options,
variables=self.pipeline_options,
jar=self.jar,
job_class=self.job_class,
process_line_callback=process_line_callback,
process_line_callback=self.process_line_callback,
)
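The synchronous path keeps the `check_if_running` handling: with `CheckJobRunning.WaitForRun` the operator polls `is_job_dataflow_running` until no job with the configured name is active before submitting. A hedged configuration sketch, assuming `DataflowConfiguration` and `CheckJobRunning` are imported from the Google provider's Dataflow operators module as in this file; values are placeholders.

```python
from airflow.providers.google.cloud.operators.dataflow import CheckJobRunning, DataflowConfiguration

# Wait for any already-running Dataflow job with the same name to finish before submitting.
dataflow_config = DataflowConfiguration(
    job_name="wordcount",
    location="us-central1",
    check_if_running=CheckJobRunning.WaitForRun,
)
```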

async def execute_async(self, context: Context):
# Get the running event loop so blocking I/O can be dispatched to the default executor
loop = asyncio.get_event_loop()
if self.jar.lower().startswith("gs://"):
gcs_hook = GCSHook(self.gcp_conn_id)
# Run the synchronous `enter_context()` call in a separate thread via the default
# executor (`None`). `run_in_executor()` returns the file object produced by the GCS
# hook's `provide_file()` without blocking the event loop, so the staged file can be
# used from the async code that follows.
create_tmp_file_call = gcs_hook.provide_file(object_url=self.jar)
tmp_gcs_file: IO[str] = await loop.run_in_executor(
None, contextlib.ExitStack().enter_context, create_tmp_file_call
)
self.jar = tmp_gcs_file.name

if self.is_dataflow and self.dataflow_hook:
DataflowJobLink.persist(
self,
context,
self.dataflow_config.project_id,
self.dataflow_config.location,
self.dataflow_job_id,
)
with self.dataflow_hook.provide_authorized_gcloud():
self.pipeline_options["jobName"] = self.dataflow_job_name
self.defer(
trigger=BeamJavaPipelineTrigger(
variables=self.pipeline_options,
jar=self.jar,
job_class=self.job_class,
runner=self.runner,
check_if_running=self.dataflow_config.check_if_running == CheckJobRunning.WaitForRun,
project_id=self.dataflow_config.project_id,
location=self.dataflow_config.location,
job_name=self.dataflow_job_name,
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.dataflow_config.impersonation_chain,
poll_sleep=self.dataflow_config.poll_sleep,
cancel_timeout=self.dataflow_config.cancel_timeout,
),
method_name="execute_complete",
)
else:
self.defer(
trigger=BeamJavaPipelineTrigger(
variables=self.pipeline_options,
jar=self.jar,
job_class=self.job_class,
runner=self.runner,
check_if_running=self.dataflow_config.check_if_running == CheckJobRunning.WaitForRun,
),
method_name="execute_complete",
)
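The async path stages a `gs://` jar by pushing the blocking `enter_context()` call into the default thread-pool executor so the event loop stays responsive. Below is a self-contained sketch of that pattern; the helper names are hypothetical and a temporary file stands in for the GCS hook's `provide_file()` context manager.

```python
import asyncio
import contextlib
import tempfile


def provide_local_file():
    # Stand-in for gcs_hook.provide_file(...): a synchronous context manager yielding a file object.
    return tempfile.NamedTemporaryFile(mode="w+", suffix=".jar")


async def stage_file() -> None:
    loop = asyncio.get_event_loop()
    stack = contextlib.ExitStack()
    # Enter the sync context manager in a worker thread; the event loop keeps running meanwhile.
    tmp_file = await loop.run_in_executor(None, stack.enter_context, provide_local_file())
    print("staged at", tmp_file.name)
    stack.close()  # release the temporary file when done


asyncio.run(stage_file())
```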

def on_kill(self) -> None:
if self.dataflow_hook and self.dataflow_job_id: