Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import stat
import tempfile
from abc import ABC, ABCMeta, abstractmethod
from collections.abc import Callable, Sequence
from collections.abc import Sequence
from concurrent.futures import ThreadPoolExecutor, as_completed
from contextlib import ExitStack
from functools import partial
Expand All @@ -51,7 +51,6 @@
from airflow.providers.google.cloud.hooks.dataflow import (
DEFAULT_DATAFLOW_LOCATION,
DataflowHook,
process_line_and_extract_dataflow_job_id_callback,
)
from airflow.providers.google.cloud.hooks.gcs import GCSHook, _parse_gcs_url
from airflow.providers.google.cloud.links.dataflow import DataflowJobLink
Expand Down Expand Up @@ -80,6 +79,7 @@ class BeamDataflowMixin(metaclass=ABCMeta):

dataflow_hook: DataflowHook | None
dataflow_config: DataflowConfiguration
dataflow_job_id: str | None
gcp_conn_id: str
dataflow_support_impersonation: bool = True

Expand All @@ -94,16 +94,24 @@ def _set_dataflow(
self,
pipeline_options: dict,
job_name_variable_key: str | None = None,
) -> tuple[str, dict, Callable[[str], None], Callable[[], bool]]:
) -> tuple[str, dict]:
self.dataflow_hook = self.__set_dataflow_hook()
self.dataflow_config.project_id = self.dataflow_config.project_id or self.dataflow_hook.project_id
dataflow_job_name = self.__get_dataflow_job_name()
pipeline_options = self.__get_dataflow_pipeline_options(
pipeline_options, dataflow_job_name, job_name_variable_key
)
process_line_callback = self.__get_dataflow_process_callback()
is_dataflow_job_id_exist_callback = self.__is_dataflow_job_id_exist_callback()
return dataflow_job_name, pipeline_options, process_line_callback, is_dataflow_job_id_exist_callback
return dataflow_job_name, pipeline_options

def _resolve_dataflow_job_id(self, job_name: str | None) -> None:
"""Resolve `dataflow_job_id` by looking up an active job whose name matches `job_name`."""
if self.dataflow_job_id or not self.dataflow_hook or not job_name:
return
self.dataflow_job_id = self.dataflow_hook.fetch_job_id_by_name(
name=job_name,
project_id=self.dataflow_config.project_id,
location=self.dataflow_config.location or DEFAULT_DATAFLOW_LOCATION,
)

def __set_dataflow_hook(self) -> DataflowHook:
self.dataflow_hook = DataflowHook(
Expand Down Expand Up @@ -144,20 +152,6 @@ def __get_dataflow_pipeline_options(
)
return pipeline_options

def __get_dataflow_process_callback(self) -> Callable[[str], None]:
def set_current_dataflow_job_id(job_id):
self.dataflow_job_id = job_id

return process_line_and_extract_dataflow_job_id_callback(
on_new_job_id_callback=set_current_dataflow_job_id
)

def __is_dataflow_job_id_exist_callback(self) -> Callable[[], bool]:
def is_dataflow_job_id_exist() -> bool:
return True if self.dataflow_job_id else False

return is_dataflow_job_id_exist


class BeamBasePipelineOperator(BaseOperator, BeamDataflowMixin, ABC):
"""
Expand Down Expand Up @@ -240,20 +234,13 @@ def _init_pipeline_options(
self,
format_pipeline_options: bool = False,
job_name_variable_key: str | None = None,
) -> tuple[bool, str | None, dict, Callable[[str], None] | None, Callable[[], bool] | None]:
) -> tuple[bool, str | None, dict]:
self.beam_hook = BeamHook(runner=self.runner)
pipeline_options = self.default_pipeline_options.copy()
process_line_callback: Callable[[str], None] | None = None
is_dataflow_job_id_exist_callback: Callable[[], bool] | None = None
is_dataflow = self.runner.lower() == BeamRunnerType.DataflowRunner.lower()
dataflow_job_name: str | None = None
if is_dataflow:
(
dataflow_job_name,
pipeline_options,
process_line_callback,
is_dataflow_job_id_exist_callback,
) = self._set_dataflow(
dataflow_job_name, pipeline_options = self._set_dataflow(
pipeline_options=pipeline_options,
job_name_variable_key=job_name_variable_key,
)
Expand All @@ -262,24 +249,11 @@ def _init_pipeline_options(
pipeline_options.update(self.pipeline_options)

if format_pipeline_options:
snake_case_pipeline_options = {
pipeline_options = {
convert_camel_to_snake(key): pipeline_options[key] for key in pipeline_options
}
return (
is_dataflow,
dataflow_job_name,
snake_case_pipeline_options,
process_line_callback,
is_dataflow_job_id_exist_callback,
)

return (
is_dataflow,
dataflow_job_name,
pipeline_options,
process_line_callback,
is_dataflow_job_id_exist_callback,
)
return is_dataflow, dataflow_job_name, pipeline_options

@property
def extra_links_params(self) -> dict[str, Any]:
Expand Down Expand Up @@ -395,8 +369,6 @@ def execute(self, context: Context):
self.is_dataflow,
self.dataflow_job_name,
self.snake_case_pipeline_options,
self.process_line_callback,
self.is_dataflow_job_id_exist_callback,
) = self._init_pipeline_options(format_pipeline_options=True, job_name_variable_key="job_name")
if not self.beam_hook:
raise AirflowException("Beam hook is not defined.")
Expand Down Expand Up @@ -455,9 +427,8 @@ def execute_on_dataflow(self, context: Context):
py_interpreter=self.py_interpreter,
py_requirements=self.py_requirements,
py_system_site_packages=self.py_system_site_packages,
process_line_callback=self.process_line_callback,
is_dataflow_job_id_exist_callback=self.is_dataflow_job_id_exist_callback,
)
self._resolve_dataflow_job_id(self.dataflow_job_name)

location = self.dataflow_config.location or DEFAULT_DATAFLOW_LOCATION
DataflowJobLink.persist(
Expand Down Expand Up @@ -583,8 +554,6 @@ def execute(self, context: Context):
self.is_dataflow,
self.dataflow_job_name,
self.pipeline_options,
self.process_line_callback,
self.is_dataflow_job_id_exist_callback,
) = self._init_pipeline_options()
if not self.beam_hook:
raise AirflowException("Beam hook is not defined.")
Expand Down Expand Up @@ -651,9 +620,8 @@ def execute_on_dataflow(self, context: Context):
variables=self.pipeline_options,
jar=self.jar,
job_class=self.job_class,
process_line_callback=self.process_line_callback,
is_dataflow_job_id_exist_callback=self.is_dataflow_job_id_exist_callback,
)
self._resolve_dataflow_job_id(self.dataflow_job_name)
if self.dataflow_job_name and self.dataflow_config.location:
DataflowJobLink.persist(
context=context,
Expand Down Expand Up @@ -798,8 +766,6 @@ def execute(self, context: Context):
is_dataflow,
dataflow_job_name,
snake_case_pipeline_options,
process_line_callback,
_,
) = self._init_pipeline_options(format_pipeline_options=True, job_name_variable_key="job_name")

if not self.beam_hook:
Expand All @@ -821,8 +787,8 @@ def execute(self, context: Context):
go_artifact.start_pipeline(
beam_hook=self.beam_hook,
variables=snake_case_pipeline_options,
process_line_callback=process_line_callback,
)
self._resolve_dataflow_job_id(dataflow_job_name)
DataflowJobLink.persist(context=context)
if dataflow_job_name and self.dataflow_config.location:
self.dataflow_hook.wait_for_done(
Expand All @@ -836,7 +802,6 @@ def execute(self, context: Context):
go_artifact.start_pipeline(
beam_hook=self.beam_hook,
variables=snake_case_pipeline_options,
process_line_callback=process_line_callback,
)

def on_kill(self) -> None:
Expand All @@ -861,7 +826,6 @@ def start_pipeline(
self,
beam_hook: BeamHook,
variables: dict,
process_line_callback: Callable[[str], None] | None = None,
) -> None: ...


Expand All @@ -881,12 +845,10 @@ def start_pipeline(
self,
beam_hook: BeamHook,
variables: dict,
process_line_callback: Callable[[str], None] | None = None,
) -> None:
beam_hook.start_go_pipeline(
variables=variables,
go_file=self.file,
process_line_callback=process_line_callback,
should_init_module=self.should_init_go_module,
)

Expand Down Expand Up @@ -931,13 +893,11 @@ def start_pipeline(
self,
beam_hook: BeamHook,
variables: dict,
process_line_callback: Callable[[str], None] | None = None,
) -> None:
beam_hook.start_go_pipeline_with_binary(
variables=variables,
launcher_binary=self.launcher,
worker_binary=self.worker,
process_line_callback=process_line_callback,
)


Expand Down
38 changes: 28 additions & 10 deletions providers/apache/beam/tests/unit/apache/beam/operators/test_beam.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ def test_exec_dataflow_runner(
start_python_dataflow.
"""
gcs_provide_file = gcs_hook.return_value.provide_file
dataflow_hook_mock.return_value.fetch_job_id_by_name.return_value = None
op = BeamRunPythonPipelineOperator(
dataflow_config={"impersonation_chain": TEST_IMPERSONATION_ACCOUNT},
runner="DataflowRunner",
Expand Down Expand Up @@ -261,8 +262,6 @@ def test_exec_dataflow_runner(
py_interpreter=PY_INTERPRETER,
py_requirements=None,
py_system_site_packages=False,
process_line_callback=mock.ANY,
is_dataflow_job_id_exist_callback=mock.ANY,
)

@mock.patch(BEAM_OPERATOR_PATH.format("DataflowJobLink.persist"))
Expand All @@ -281,6 +280,31 @@ def test_exec_dataflow_runner__no_dataflow_job_name(
op.execute({})
assert op.dataflow_config.job_name == op.task_id

@mock.patch(BEAM_OPERATOR_PATH.format("DataflowJobLink.persist"))
@mock.patch(BEAM_OPERATOR_PATH.format("BeamHook"))
@mock.patch(BEAM_OPERATOR_PATH.format("DataflowHook"))
@mock.patch(BEAM_OPERATOR_PATH.format("GCSHook"))
def test_exec_dataflow_runner__resolves_job_id_by_name(
self, gcs_hook, dataflow_hook_mock, beam_hook_mock, persist_link_mock
):
"""After the Beam launcher returns, the job id is resolved via DataflowHook.fetch_job_id_by_name."""
resolved_id = "2026-05-28_07_15_42-1234567890"
dataflow_hook_mock.return_value.fetch_job_id_by_name.return_value = resolved_id
op = BeamRunPythonPipelineOperator(
dataflow_config={"impersonation_chain": TEST_IMPERSONATION_ACCOUNT},
runner="DataflowRunner",
**self.default_op_kwargs,
)

op.execute({})

dataflow_hook_mock.return_value.fetch_job_id_by_name.assert_called_once_with(
name=op.dataflow_job_name,
project_id=op.dataflow_config.project_id,
location=op.dataflow_config.location,
)
assert op.dataflow_job_id == resolved_id

@mock.patch(BEAM_OPERATOR_PATH.format("DataflowJobLink.persist"))
@mock.patch(BEAM_OPERATOR_PATH.format("BeamHook"))
@mock.patch(BEAM_OPERATOR_PATH.format("GCSHook"))
Expand Down Expand Up @@ -451,6 +475,7 @@ def test_exec_dataflow_runner(self, gcs_hook, dataflow_hook_mock, beam_hook_mock
)
gcs_provide_file = gcs_hook.return_value.provide_file
dataflow_hook_mock.return_value.is_job_dataflow_running.return_value = False
dataflow_hook_mock.return_value.fetch_job_id_by_name.return_value = None

op.execute({})

Expand Down Expand Up @@ -484,8 +509,6 @@ def test_exec_dataflow_runner(self, gcs_hook, dataflow_hook_mock, beam_hook_mock
variables=expected_options,
jar=gcs_provide_file.return_value.__enter__.return_value.name,
job_class=JOB_CLASS,
process_line_callback=mock.ANY,
is_dataflow_job_id_exist_callback=mock.ANY,
)

@mock.patch(BEAM_OPERATOR_PATH.format("DataflowJobLink.persist"))
Expand Down Expand Up @@ -662,7 +685,6 @@ def test_exec_direct_runner_with_gcs_go_file(self, gcs_hook, beam_hook_mock, _):
start_go_pipeline_method.assert_called_once_with(
variables=expected_options,
go_file=expected_go_file,
process_line_callback=None,
should_init_module=True,
)

Expand Down Expand Up @@ -713,7 +735,6 @@ def gcs_download_side_effect(bucket_name: str, object_name: str, filename: str)
variables=expected_options,
launcher_binary=expected_binary,
worker_binary=expected_binary,
process_line_callback=None,
)

@mock.patch(BEAM_OPERATOR_PATH.format("BeamHook"))
Expand All @@ -734,7 +755,6 @@ def test_exec_direct_runner_with_local_go_file(self, init_module, beam_hook_mock
start_go_pipeline_method.assert_called_once_with(
variables={"labels": {"airflow-version": TEST_VERSION}},
go_file=local_go_file_path,
process_line_callback=None,
should_init_module=False,
)

Expand All @@ -758,7 +778,6 @@ def test_exec_direct_runner_with_local_launcher_binary(self, mock_beam_hook):
variables={"labels": {"airflow-version": TEST_VERSION}},
launcher_binary=expected_binary,
worker_binary=expected_binary,
process_line_callback=None,
)

@mock.patch(BEAM_OPERATOR_PATH.format("DataflowJobLink.persist"))
Expand Down Expand Up @@ -811,7 +830,6 @@ def test_exec_dataflow_runner_with_go_file(
beam_hook_mock.return_value.start_go_pipeline.assert_called_once_with(
variables=expected_options,
go_file=expected_go_file,
process_line_callback=mock.ANY,
should_init_module=True,
)
dataflow_hook_mock.return_value.wait_for_done.assert_called_once_with(
Expand Down Expand Up @@ -897,7 +915,6 @@ def gcs_download_side_effect(bucket_name: str, object_name: str, filename: str)
variables=expected_options,
launcher_binary=expected_launcher_binary,
worker_binary=expected_worker_binary,
process_line_callback=mock.ANY,
)
mock_persist_link.assert_called_once_with(context={})
wait_for_done_method.assert_called_once_with(
Expand Down Expand Up @@ -1049,6 +1066,7 @@ def test_on_kill_dataflow_runner(self, dataflow_hook_mock, _, __, ___):
@mock.patch(BEAM_OPERATOR_PATH.format("GCSHook"))
def test_on_kill_direct_runner(self, _, dataflow_mock, __):
dataflow_cancel_job = dataflow_mock.return_value.cancel_job
dataflow_mock.return_value.fetch_job_id_by_name.return_value = None
op = BeamRunPythonPipelineOperator(runner="DataflowRunner", **self.default_op_kwargs)
if AIRFLOW_V_3_0_PLUS:
with pytest.raises(TaskDeferred):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1117,6 +1117,38 @@ def is_job_dataflow_running(
)
return jobs_controller.is_job_running()

@GoogleBaseHook.fallback_to_default_project_id
def fetch_job_id_by_name(
self,
name: str,
project_id: str,
location: str = DEFAULT_DATAFLOW_LOCATION,
) -> str | None:
"""
Look up a single Dataflow job id by name prefix.

Returns the id when exactly one active job's name starts with ``name``;
``None`` otherwise.
"""
jobs_controller = _DataflowJobsController(
dataflow=self.get_conn(),
project_number=project_id,
name=name,
location=location,
poll_sleep=self.poll_sleep,
drain_pipeline=self.drain_pipeline,
num_retries=self.num_retries,
cancel_timeout=self.cancel_timeout,
)
try:
jobs = jobs_controller._get_current_jobs()
except Exception:
self.log.warning("Failed to look up Dataflow job id by name %r.", name, exc_info=True)
return None
if len(jobs) != 1:
return None
Comment on lines +1148 to +1149
Copy link
Copy Markdown
Contributor

@MaksYermak MaksYermak May 29, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@evgeniy-b as I understand in case when users run in parallel 2 or more Jobs with the same name or on Dataflow the Job with this name already present than this code returns None as JobID value, please correct me if I am wrong?

In the current logic with callbacks the code parse Apache Beam logs for availability of JobID and when getting it then starts the waiting process in deferrable or non-deferable mode. It means that we always have unique Job ID.

This new logic looks for me as a breaking change because returns None as JobID in case when in Dataflow the users have 2 or more Jobs with the same name. It is possible scenario for the most of our users because in Dataflow is impossible to remove finished Jobs the user can only archived it. And our _fetch_all_jobs method does not sort Jobs by finished or running and returns all Jobs with the same name.

Copy link
Copy Markdown
Contributor Author

@evgeniy-b evgeniy-b May 29, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let me explain a bit how I arrived here. On an airflow cluster I maintain I noticed python beam jobs running with deferrable=False, so I switched that flag to true to not waste worker resources. On the next day the jobs failed while transitioning to async triggers because their STDOUT didn't contain the job ID. In the sync mode a missing job ID doesn't prevent the task from succeeding:

_DataflowJobsController.wait_for_done polls self._refresh_jobs():

def wait_for_done(self) -> None:
"""Wait for result of submitted job."""
self.log.info("Start waiting for done.")
self._refresh_jobs()
while self._jobs and not all(
self.job_reached_terminal_state(job, self._wait_until_finished, self._expected_terminal_state)
for job in self._jobs
):
self.log.info("Waiting for done. Sleep %s s", self._poll_sleep)
time.sleep(self._poll_sleep)
self._refresh_jobs()

_refresh_jobs calls self._get_current_jobs():

def _refresh_jobs(self) -> None:
"""
Get all jobs by name.
:return: jobs
"""
self._jobs = self._get_current_jobs()

_get_current_jobs — with no _job_id — calls self._fetch_jobs_by_prefix_name(self._job_name.lower()):

def _get_current_jobs(self) -> list[dict]:
"""
Get list of jobs that start with job name or id.
:return: list of jobs including id's
"""
if not self._multiple_jobs and self._job_id:
return [self.fetch_job_by_id(self._job_id)]
if self._jobs:
return [self.fetch_job_by_id(job["id"]) for job in self._jobs]
if self._job_name:
jobs = self._fetch_jobs_by_prefix_name(self._job_name.lower())

_fetch_jobs_by_prefix_name calls self._fetch_all_jobs() and returns every prefix-matched job (archived + running, no terminal-state filter):

def _fetch_jobs_by_prefix_name(self, prefix_name: str) -> list[dict]:
jobs = self._fetch_all_jobs()
jobs = [job for job in jobs if job["name"].startswith(prefix_name)]
return jobs

So today's sync path already silently picks up every prefix-matched job whenever the regex misses.

With default append_job_name=True the job name will be unique and job ID will be retrieved.
But you are right, it is a degradation: for jobs without unique names but printing out their IDs to console, the job ID will become missing.

I guess an alternative could be to replicate the sync mode's behavior in the async path which currently fails without job_id. However it means that xcom and a link to the job will stay broken.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think job ID in output detection should be reverted. While it is awkward in principle, it is the only way (?) to reliably get ID when job names are not unique. Then name-based ID detection can be used as a fallback but only when append_job_name=True. And if the trigger receives empty job ID it should fallback to polling status of all jobs matching the name (and not in terminal status).
@MaksYermak what's your take on this?

return jobs[0].get("id") or None

@GoogleBaseHook.fallback_to_default_project_id
def cancel_job(
self,
Expand Down
Loading
Loading