From ce3387ae5471d6037e2192c0f2fd2dd33d8fe250 Mon Sep 17 00:00:00 2001 From: raphaelauv Date: Mon, 22 Apr 2024 23:08:24 +0200 Subject: [PATCH] review_1 --- airflow/operators/trigger_dagrun.py | 20 +++++++++++++------- tests/operators/test_trigger_dagrun.py | 19 ++++++++++++++++++- 2 files changed, 31 insertions(+), 8 deletions(-) diff --git a/airflow/operators/trigger_dagrun.py b/airflow/operators/trigger_dagrun.py index f6dc266a36fbd..b6c24f2180fbf 100644 --- a/airflow/operators/trigger_dagrun.py +++ b/airflow/operators/trigger_dagrun.py @@ -28,7 +28,13 @@ from airflow.api.common.trigger_dag import trigger_dag from airflow.configuration import conf -from airflow.exceptions import AirflowException, AirflowSkipException, DagNotFound, DagRunAlreadyExists, RemovedInAirflow3Warning +from airflow.exceptions import ( + AirflowException, + AirflowSkipException, + DagNotFound, + DagRunAlreadyExists, + RemovedInAirflow3Warning, +) from airflow.models.baseoperator import BaseOperator from airflow.models.baseoperatorlink import BaseOperatorLink from airflow.models.dag import DagModel @@ -90,7 +96,7 @@ class TriggerDagRunOperator(BaseOperator): (default: 60) :param allowed_states: List of allowed states, default is ``['success']``. :param failed_states: List of failed or dis-allowed states, default is ``None``. - :param soft_fail: Set to true to mark the task as SKIPPED on DagRunAlreadyExists + :param skip_when_already_exists: Set to true to mark the task as SKIPPED if a dag_run already exists :param deferrable: If waiting for completion, whether or not to defer the task until done, default is ``False``. :param execution_date: Deprecated parameter; same as ``logical_date``. @@ -102,7 +108,7 @@ class TriggerDagRunOperator(BaseOperator): "logical_date", "conf", "wait_for_completion", - "soft_fail", + "skip_when_already_exists", ) template_fields_renderers = {"conf": "py"} ui_color = "#ffefeb" @@ -120,7 +126,7 @@ def __init__( poke_interval: int = 60, allowed_states: list[str] | None = None, failed_states: list[str] | None = None, - soft_fail: bool = False, + skip_when_already_exists: bool = False, deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), execution_date: str | datetime.datetime | None = None, **kwargs, @@ -140,7 +146,7 @@ def __init__( self.failed_states = [DagRunState(s) for s in failed_states] else: self.failed_states = [DagRunState.FAILED] - self.soft_fail = soft_fail + self.skip_when_already_exists = skip_when_already_exists self._defer = deferrable if execution_date is not None: @@ -200,9 +206,9 @@ def execute(self, context: Context): dag_run = e.dag_run dag.clear(start_date=dag_run.logical_date, end_date=dag_run.logical_date) else: - if self.soft_fail: + if self.skip_when_already_exists: raise AirflowSkipException( - "Skipping due to soft_fail is set to True and DagRunAlreadyExists" + "Skipping due to skip_when_already_exists is set to True and DagRunAlreadyExists" ) raise e if dag_run is None: diff --git a/tests/operators/test_trigger_dagrun.py b/tests/operators/test_trigger_dagrun.py index 9eed9b786ea9f..522eacf493994 100644 --- a/tests/operators/test_trigger_dagrun.py +++ b/tests/operators/test_trigger_dagrun.py @@ -35,7 +35,7 @@ from airflow.triggers.external_task import DagStateTrigger from airflow.utils import timezone from airflow.utils.session import create_session -from airflow.utils.state import State +from airflow.utils.state import State, TaskInstanceState from airflow.utils.types import DagRunType pytestmark = pytest.mark.db_test @@ -322,6 +322,23 @@ def test_trigger_dagrun_with_reset_dag_run_false_fail(self, trigger_run_id, trig with pytest.raises(DagRunAlreadyExists): task.run(start_date=logical_date, end_date=logical_date, ignore_ti_state=True) + def test_trigger_dagrun_with_skip_when_already_exists(self): + """Test TriggerDagRunOperator with skip_when_already_exists.""" + execution_date = DEFAULT_DATE + task = TriggerDagRunOperator( + task_id="test_task", + trigger_dag_id=TRIGGERED_DAG_ID, + trigger_run_id="dummy_run_id", + execution_date=None, + reset_dag_run=False, + skip_when_already_exists=True, + dag=self.dag, + ) + task.run(start_date=execution_date, end_date=execution_date, ignore_ti_state=True) + assert task.get_task_instances()[0].state == TaskInstanceState.SUCCESS + task.run(start_date=execution_date, end_date=execution_date, ignore_ti_state=True) + assert task.get_task_instances()[0].state == TaskInstanceState.SKIPPED + @pytest.mark.parametrize( "trigger_run_id, trigger_logical_date, expected_dagruns_count", [