From 574102fd291930ed45262a40fb7033a122152541 Mon Sep 17 00:00:00 2001
From: gaurav7261 <142777151+gaurav7261@users.noreply.github.com>
Date: Thu, 11 Jan 2024 22:24:47 +0530
Subject: [PATCH] [FEAT] adds repair run functionality for databricks (#36601)

* [FEAT] adds repair run functionality for databricks

* [FIX] added latest repair run and test cases

* [FIX] comma typo

* [FIX] check for DatabricksRunNowOperator instance before doing repair run

* [FIX] fixed static checks

* [FIX] fixed static checks

* Update airflow/providers/databricks/hooks/databricks.py

Co-authored-by: Andrey Anshin

* [FIX] type annotations

* [FIX] change from log.warn to log.warning

* Update airflow/providers/databricks/operators/databricks.py

Co-authored-by: Andrey Anshin

* [FIX] CI Static check

---------

Co-authored-by: GauravM
Co-authored-by: GauravM
Co-authored-by: GauravM
Co-authored-by: Andrey Anshin
---
 .../providers/databricks/hooks/databricks.py  | 15 +++-
 .../databricks/operators/databricks.py        | 17 +++++
 .../databricks/hooks/test_databricks.py       | 73 +++++++++++++++++++
 3 files changed, 103 insertions(+), 2 deletions(-)

diff --git a/airflow/providers/databricks/hooks/databricks.py b/airflow/providers/databricks/hooks/databricks.py
index b39e3d622c2a7..bc3bd902095ef 100644
--- a/airflow/providers/databricks/hooks/databricks.py
+++ b/airflow/providers/databricks/hooks/databricks.py
@@ -519,13 +519,24 @@ def delete_run(self, run_id: int) -> None:
         json = {"run_id": run_id}
         self._do_api_call(DELETE_RUN_ENDPOINT, json)
 
-    def repair_run(self, json: dict) -> None:
+    def repair_run(self, json: dict) -> int:
         """
         Re-run one or more tasks.
 
         :param json: repair a job run.
         """
-        self._do_api_call(REPAIR_RUN_ENDPOINT, json)
+        response = self._do_api_call(REPAIR_RUN_ENDPOINT, json)
+        return response["repair_id"]
+
+    def get_latest_repair_id(self, run_id: int) -> int | None:
+        """Get the latest repair ID for run_id, or None if the run has never been repaired."""
+        json = {"run_id": run_id, "include_history": True}
+        response = self._do_api_call(GET_RUN_ENDPOINT, json)
+        repair_history = response["repair_history"]
+        if len(repair_history) == 1:
+            return None
+        else:
+            return repair_history[-1]["id"]
 
     def get_cluster_state(self, cluster_id: str) -> ClusterState:
         """
diff --git a/airflow/providers/databricks/operators/databricks.py b/airflow/providers/databricks/operators/databricks.py
index edea8b4e59738..5d8b62643fe27 100644
--- a/airflow/providers/databricks/operators/databricks.py
+++ b/airflow/providers/databricks/operators/databricks.py
@@ -88,6 +88,19 @@ def _handle_databricks_operator_execution(operator, hook, log, context) -> None:
                     f"{operator.task_id} failed with terminal state: {run_state} "
                     f"and with the error {run_state.state_message}"
                 )
+                if isinstance(operator, DatabricksRunNowOperator) and operator.repair_run:
+                    operator.repair_run = False
+                    log.warning(
+                        "%s but since repair run is set, repairing the run with all failed tasks",
+                        error_message,
+                    )
+
+                    latest_repair_id = hook.get_latest_repair_id(operator.run_id)
+                    repair_json = {"run_id": operator.run_id, "rerun_all_failed_tasks": True}
+                    if latest_repair_id is not None:
+                        repair_json["latest_repair_id"] = latest_repair_id
+                    operator.json["latest_repair_id"] = hook.repair_run(repair_json)
+                    _handle_databricks_operator_execution(operator, hook, log, context)
                 raise AirflowException(error_message)
             else:
                 log.info("%s in run state: %s", operator.task_id, run_state)
@@ -623,6 +636,7 @@ class DatabricksRunNowOperator(BaseOperator):
         - ``jar_params``
         - ``spark_submit_params``
         - ``idempotency_token``
+        - ``repair_run``
 
     :param job_id: the job_id of the existing Databricks job.
         This field will be templated.
@@ -711,6 +725,7 @@ class DatabricksRunNowOperator(BaseOperator):
     :param do_xcom_push: Whether we should push run_id and run_page_url to xcom.
     :param wait_for_termination: if we should wait for termination of the job run. ``True`` by default.
    :param deferrable: Run operator in the deferrable mode.
+    :param repair_run: Repair the Databricks run in case of failure; does not take effect in deferrable mode.
     """
 
     # Used in airflow.models.BaseOperator
@@ -741,6 +756,7 @@ def __init__(
         do_xcom_push: bool = True,
         wait_for_termination: bool = True,
         deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        repair_run: bool = False,
         **kwargs,
     ) -> None:
         """Create a new ``DatabricksRunNowOperator``."""
@@ -753,6 +769,7 @@ def __init__(
         self.databricks_retry_args = databricks_retry_args
         self.wait_for_termination = wait_for_termination
         self.deferrable = deferrable
+        self.repair_run = repair_run
 
         if job_id is not None:
             self.json["job_id"] = job_id
diff --git a/tests/providers/databricks/hooks/test_databricks.py b/tests/providers/databricks/hooks/test_databricks.py
index 1baaab1fea6cb..c9004e7175fd7 100644
--- a/tests/providers/databricks/hooks/test_databricks.py
+++ b/tests/providers/databricks/hooks/test_databricks.py
@@ -683,6 +683,79 @@ def test_repair_run(self, mock_requests):
             timeout=self.hook.timeout_seconds,
         )
 
+    @mock.patch("airflow.providers.databricks.hooks.databricks_base.requests")
+    def test_negative_get_latest_repair_id(self, mock_requests):
+        mock_requests.codes.ok = 200
+        mock_requests.get.return_value.json.return_value = {
+            "job_id": JOB_ID,
+            "run_id": RUN_ID,
+            "state": {"life_cycle_state": "RUNNING", "result_state": "RUNNING"},
+            "repair_history": [
+                {
+                    "type": "ORIGINAL",
+                    "start_time": 1704528798059,
+                    "end_time": 1704529026679,
+                    "state": {
+                        "life_cycle_state": "RUNNING",
+                        "result_state": "RUNNING",
+                        "state_message": "dummy",
+                        "user_cancelled_or_timedout": "false",
+                    },
+                    "task_run_ids": [396529700633015, 1111270934390307],
+                }
+            ],
+        }
+        latest_repair_id = self.hook.get_latest_repair_id(RUN_ID)
+
+        assert latest_repair_id is None
+
+    @mock.patch("airflow.providers.databricks.hooks.databricks_base.requests")
+    def test_positive_get_latest_repair_id(self, mock_requests):
+        mock_requests.codes.ok = 200
+        mock_requests.get.return_value.json.return_value = {
+            "job_id": JOB_ID,
+            "run_id": RUN_ID,
+            "state": {"life_cycle_state": "RUNNING", "result_state": "RUNNING"},
+            "repair_history": [
+                {
+                    "type": "ORIGINAL",
+                    "start_time": 1704528798059,
+                    "end_time": 1704529026679,
+                    "state": {
+                        "life_cycle_state": "TERMINATED",
+                        "result_state": "CANCELED",
+                        "state_message": "dummy_original",
+                        "user_cancelled_or_timedout": "false",
+                    },
+                    "task_run_ids": [396529700633015, 1111270934390307],
+                },
+                {
+                    "type": "REPAIR",
+                    "start_time": 1704530276423,
+                    "end_time": 1704530363736,
+                    "state": {
+                        "life_cycle_state": "TERMINATED",
+                        "result_state": "CANCELED",
+                        "state_message": "dummy_repair_1",
+                        "user_cancelled_or_timedout": "true",
+                    },
+                    "id": 108607572123234,
+                    "task_run_ids": [396529700633015, 1111270934390307],
+                },
+                {
+                    "type": "REPAIR",
+                    "start_time": 1704531464690,
+                    "end_time": 1704531481590,
+                    "state": {"life_cycle_state": "RUNNING", "result_state": "RUNNING"},
+                    "id": 52532060060836,
+                    "task_run_ids": [396529700633015, 1111270934390307],
+                },
+            ],
+        }
+        latest_repair_id = self.hook.get_latest_repair_id(RUN_ID)
+
+        assert latest_repair_id == 52532060060836
+
     @mock.patch("airflow.providers.databricks.hooks.databricks_base.requests")
     def test_get_cluster_state(self, mock_requests):
         """
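
A usage sketch for reviewers (not part of the patch): enabling the new flag from a DAG. The DAG ID, connection ID, and job_id below are illustrative placeholders, and wait_for_termination must stay True, since the repair is issued while the operator polls the run.

from __future__ import annotations

import pendulum

from airflow.models.dag import DAG
from airflow.providers.databricks.operators.databricks import DatabricksRunNowOperator

with DAG(
    dag_id="example_databricks_repair_run",  # placeholder DAG id
    start_date=pendulum.datetime(2024, 1, 1, tz="UTC"),
    schedule=None,
    catchup=False,
):
    # If the run ends in a failed terminal state, the operator issues one repair
    # with rerun_all_failed_tasks=True and then waits on the repaired run; the
    # flag is cleared before repairing, so at most one repair is attempted, and
    # it has no effect in deferrable mode.
    DatabricksRunNowOperator(
        task_id="run_databricks_job",
        databricks_conn_id="databricks_default",  # assumed existing connection
        job_id=1234,  # placeholder: ID of an existing Databricks job
        wait_for_termination=True,
        repair_run=True,
    )

The hook-level flow the operator performs can also be driven by hand; run_id=42 is a placeholder for a run already in a failed terminal state:

from airflow.providers.databricks.hooks.databricks import DatabricksHook

hook = DatabricksHook(databricks_conn_id="databricks_default")
run_id = 42  # placeholder: a failed run to repair
repair_json = {"run_id": run_id, "rerun_all_failed_tasks": True}
# Jobs API 2.1 expects the previous repair ID when a run has already been
# repaired; get_latest_repair_id returns None for a never-repaired run.
latest_repair_id = hook.get_latest_repair_id(run_id)
if latest_repair_id is not None:
    repair_json["latest_repair_id"] = latest_repair_id
repair_id = hook.repair_run(repair_json)  # returns the new repair_id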