From eb48911817b100615a1c138de3a71fbbf4bd580a Mon Sep 17 00:00:00 2001 From: raphaelauv Date: Mon, 6 May 2024 10:00:56 +0200 Subject: [PATCH] feat: soft_fail TriggerDagRunOperator (#39173) * feat: soft_fail TriggerDagRunOperator * review_1 --------- Co-authored-by: raphaelauv --- airflow/operators/trigger_dagrun.py | 16 +++++++++++++++- tests/operators/test_trigger_dagrun.py | 19 ++++++++++++++++++- 2 files changed, 33 insertions(+), 2 deletions(-) diff --git a/airflow/operators/trigger_dagrun.py b/airflow/operators/trigger_dagrun.py index f8cfa5256a57b..b6c24f2180fbf 100644 --- a/airflow/operators/trigger_dagrun.py +++ b/airflow/operators/trigger_dagrun.py @@ -28,7 +28,13 @@ from airflow.api.common.trigger_dag import trigger_dag from airflow.configuration import conf -from airflow.exceptions import AirflowException, DagNotFound, DagRunAlreadyExists, RemovedInAirflow3Warning +from airflow.exceptions import ( + AirflowException, + AirflowSkipException, + DagNotFound, + DagRunAlreadyExists, + RemovedInAirflow3Warning, +) from airflow.models.baseoperator import BaseOperator from airflow.models.baseoperatorlink import BaseOperatorLink from airflow.models.dag import DagModel @@ -90,6 +96,7 @@ class TriggerDagRunOperator(BaseOperator): (default: 60) :param allowed_states: List of allowed states, default is ``['success']``. :param failed_states: List of failed or dis-allowed states, default is ``None``. + :param skip_when_already_exists: Set to true to mark the task as SKIPPED if a dag_run already exists :param deferrable: If waiting for completion, whether or not to defer the task until done, default is ``False``. :param execution_date: Deprecated parameter; same as ``logical_date``. @@ -101,6 +108,7 @@ class TriggerDagRunOperator(BaseOperator): "logical_date", "conf", "wait_for_completion", + "skip_when_already_exists", ) template_fields_renderers = {"conf": "py"} ui_color = "#ffefeb" @@ -118,6 +126,7 @@ def __init__( poke_interval: int = 60, allowed_states: list[str] | None = None, failed_states: list[str] | None = None, + skip_when_already_exists: bool = False, deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), execution_date: str | datetime.datetime | None = None, **kwargs, @@ -137,6 +146,7 @@ def __init__( self.failed_states = [DagRunState(s) for s in failed_states] else: self.failed_states = [DagRunState.FAILED] + self.skip_when_already_exists = skip_when_already_exists self._defer = deferrable if execution_date is not None: @@ -196,6 +206,10 @@ def execute(self, context: Context): dag_run = e.dag_run dag.clear(start_date=dag_run.logical_date, end_date=dag_run.logical_date) else: + if self.skip_when_already_exists: + raise AirflowSkipException( + "Skipping due to skip_when_already_exists is set to True and DagRunAlreadyExists" + ) raise e if dag_run is None: raise RuntimeError("The dag_run should be set here!") diff --git a/tests/operators/test_trigger_dagrun.py b/tests/operators/test_trigger_dagrun.py index 9eed9b786ea9f..522eacf493994 100644 --- a/tests/operators/test_trigger_dagrun.py +++ b/tests/operators/test_trigger_dagrun.py @@ -35,7 +35,7 @@ from airflow.triggers.external_task import DagStateTrigger from airflow.utils import timezone from airflow.utils.session import create_session -from airflow.utils.state import State +from airflow.utils.state import State, TaskInstanceState from airflow.utils.types import DagRunType pytestmark = pytest.mark.db_test @@ -322,6 +322,23 @@ def test_trigger_dagrun_with_reset_dag_run_false_fail(self, trigger_run_id, trig with pytest.raises(DagRunAlreadyExists): task.run(start_date=logical_date, end_date=logical_date, ignore_ti_state=True) + def test_trigger_dagrun_with_skip_when_already_exists(self): + """Test TriggerDagRunOperator with skip_when_already_exists.""" + execution_date = DEFAULT_DATE + task = TriggerDagRunOperator( + task_id="test_task", + trigger_dag_id=TRIGGERED_DAG_ID, + trigger_run_id="dummy_run_id", + execution_date=None, + reset_dag_run=False, + skip_when_already_exists=True, + dag=self.dag, + ) + task.run(start_date=execution_date, end_date=execution_date, ignore_ti_state=True) + assert task.get_task_instances()[0].state == TaskInstanceState.SUCCESS + task.run(start_date=execution_date, end_date=execution_date, ignore_ti_state=True) + assert task.get_task_instances()[0].state == TaskInstanceState.SKIPPED + @pytest.mark.parametrize( "trigger_run_id, trigger_logical_date, expected_dagruns_count", [