Skip to content

Commit

Permalink
[FEAT] adds repair run functionality for databricks (#36601)
Browse files Browse the repository at this point in the history
* [FEAT] adds repair run functionality for databricks

* [FIX] added latest repair run and test cases

* [FIX] comma typo

* [FIX] check for DatabricksRunNowOperator instance before doing repair run

* [FIX] fixed static checks

* [FIX] fixed static checks

* Update airflow/providers/databricks/hooks/databricks.py

Co-authored-by: Andrey Anshin <Andrey.Anshin@taragol.is>

* [FIX]  type annotations

* [FIX] change from log.warn to log.warning

* Update airflow/providers/databricks/operators/databricks.py

Co-authored-by: Andrey Anshin <Andrey.Anshin@taragol.is>

* [FIX] CI Static check

---------

Co-authored-by: GauravM <gaurav@ip-192-168-0-100.ap-south-1.compute.internal>
Co-authored-by: GauravM <gaurav@ip-192-168-0-101.ap-south-1.compute.internal>
Co-authored-by: GauravM <gaurav@ip-10-20-1-171.ap-south-1.compute.internal>
Co-authored-by: Andrey Anshin <Andrey.Anshin@taragol.is>
  • Loading branch information
5 people committed Jan 11, 2024
1 parent 449c814 commit 574102f
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 2 deletions.
15 changes: 13 additions & 2 deletions airflow/providers/databricks/hooks/databricks.py
Expand Up @@ -519,13 +519,24 @@ def delete_run(self, run_id: int) -> None:
json = {"run_id": run_id}
self._do_api_call(DELETE_RUN_ENDPOINT, json)

# NOTE(review): this hunk appears to interleave the removed and added sides of the diff
# (the ``-> None`` signature and the bare ``_do_api_call`` line look like the pre-change
# lines; the ``-> int`` signature and the two lines using ``response`` look like the
# post-change lines) -- confirm against the repository before treating it as one body.
def repair_run(self, json: dict) -> None:
def repair_run(self, json: dict) -> int:
"""
Re-run one or more tasks of a job run.

:param json: payload passed unchanged to the ``REPAIR_RUN_ENDPOINT`` API call.
:return: in the ``-> int`` variant, the ``repair_id`` value read from the API response;
    the ``-> None`` variant only issues the call.
"""
self._do_api_call(REPAIR_RUN_ENDPOINT, json)
response = self._do_api_call(REPAIR_RUN_ENDPOINT, json)
return response["repair_id"]

def get_latest_repair_id(self, run_id: int) -> int | None:
    """
    Get the id of the most recent repair of a run, if the run has been repaired.

    The run is fetched with ``include_history`` enabled and the last entry of the
    returned ``repair_history`` is inspected.

    :param run_id: id of the run whose repair history is looked up.
    :return: the ``id`` of the last ``repair_history`` entry, or ``None`` when the
        history is missing, empty, or holds a single entry.
    """
    json = {"run_id": run_id, "include_history": True}
    response = self._do_api_call(GET_RUN_ENDPOINT, json)
    # Tolerate an absent or empty history instead of failing with KeyError/IndexError.
    repair_history = response.get("repair_history") or []
    # A lone entry is presumably the original attempt, i.e. nothing was repaired yet.
    if len(repair_history) <= 1:
        return None
    return repair_history[-1]["id"]

def get_cluster_state(self, cluster_id: str) -> ClusterState:
"""
Expand Down
17 changes: 17 additions & 0 deletions airflow/providers/databricks/operators/databricks.py
Expand Up @@ -88,6 +88,19 @@ def _handle_databricks_operator_execution(operator, hook, log, context) -> None:
f"{operator.task_id} failed with terminal state: {run_state} "
f"and with the error {run_state.state_message}"
)
if isinstance(operator, DatabricksRunNowOperator) and operator.repair_run:
operator.repair_run = False
log.warning(
"%s but since repair run is set, repairing the run with all failed tasks",
error_message,
)

latest_repair_id = hook.get_latest_repair_id(operator.run_id)
repair_json = {"run_id": operator.run_id, "rerun_all_failed_tasks": True}
if latest_repair_id is not None:
repair_json["latest_repair_id"] = latest_repair_id
operator.json["latest_repair_id"] = hook.repair_run(operator, repair_json)
_handle_databricks_operator_execution(operator, hook, log, context)
raise AirflowException(error_message)

else:
Expand Down Expand Up @@ -623,6 +636,7 @@ class DatabricksRunNowOperator(BaseOperator):
- ``jar_params``
- ``spark_submit_params``
- ``idempotency_token``
- ``repair_run``
:param job_id: the job_id of the existing Databricks job.
This field will be templated.
Expand Down Expand Up @@ -711,6 +725,7 @@ class DatabricksRunNowOperator(BaseOperator):
:param do_xcom_push: Whether we should push run_id and run_page_url to xcom.
:param wait_for_termination: if we should wait for termination of the job run. ``True`` by default.
:param deferrable: Run operator in the deferrable mode.
:param repair_run: Repair the databricks run in case of failure, doesn't work in deferrable mode
"""

# Used in airflow.models.BaseOperator
Expand Down Expand Up @@ -741,6 +756,7 @@ def __init__(
do_xcom_push: bool = True,
wait_for_termination: bool = True,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
repair_run: bool = False,
**kwargs,
) -> None:
"""Create a new ``DatabricksRunNowOperator``."""
Expand All @@ -753,6 +769,7 @@ def __init__(
self.databricks_retry_args = databricks_retry_args
self.wait_for_termination = wait_for_termination
self.deferrable = deferrable
self.repair_run = repair_run

if job_id is not None:
self.json["job_id"] = job_id
Expand Down
73 changes: 73 additions & 0 deletions tests/providers/databricks/hooks/test_databricks.py
Expand Up @@ -683,6 +683,79 @@ def test_repair_run(self, mock_requests):
timeout=self.hook.timeout_seconds,
)

@mock.patch("airflow.providers.databricks.hooks.databricks_base.requests")
def test_negative_get_latest_repair_id(self, mock_requests):
    """A run whose history holds only the ORIGINAL attempt yields no repair id."""
    original_attempt = {
        "type": "ORIGINAL",
        "start_time": 1704528798059,
        "end_time": 1704529026679,
        "state": {
            "life_cycle_state": "RUNNING",
            "result_state": "RUNNING",
            "state_message": "dummy",
            "user_cancelled_or_timedout": "false",
        },
        "task_run_ids": [396529700633015, 1111270934390307],
    }

    mock_requests.codes.ok = 200
    # The mocked GET response is the run payload with its repair history attached.
    mock_requests.get.return_value.json.return_value = {
        "job_id": JOB_ID,
        "run_id": RUN_ID,
        "state": {"life_cycle_state": "RUNNING", "result_state": "RUNNING"},
        "repair_history": [original_attempt],
    }

    assert self.hook.get_latest_repair_id(RUN_ID) is None

@mock.patch("airflow.providers.databricks.hooks.databricks_base.requests")
def test_positive_get_latest_repair_id(self, mock_requests):
    """With repairs present in the history, the id of the last entry is returned."""
    expected_repair_id = 52532060060836

    original_attempt = {
        "type": "ORIGINAL",
        "start_time": 1704528798059,
        "end_time": 1704529026679,
        "state": {
            "life_cycle_state": "TERMINATED",
            "result_state": "CANCELED",
            "state_message": "dummy_original",
            "user_cancelled_or_timedout": "false",
        },
        "task_run_ids": [396529700633015, 1111270934390307],
    }
    first_repair = {
        "type": "REPAIR",
        "start_time": 1704530276423,
        "end_time": 1704530363736,
        "state": {
            "life_cycle_state": "TERMINATED",
            "result_state": "CANCELED",
            "state_message": "dummy_repair_1",
            "user_cancelled_or_timedout": "true",
        },
        "id": 108607572123234,
        "task_run_ids": [396529700633015, 1111270934390307],
    }
    latest_repair = {
        "type": "REPAIR",
        "start_time": 1704531464690,
        "end_time": 1704531481590,
        "state": {"life_cycle_state": "RUNNING", "result_state": "RUNNING"},
        "id": expected_repair_id,
        "task_run_ids": [396529700633015, 1111270934390307],
    }

    mock_requests.codes.ok = 200
    # History order matters: the hook is expected to pick the final entry.
    mock_requests.get.return_value.json.return_value = {
        "job_id": JOB_ID,
        "run_id": RUN_ID,
        "state": {"life_cycle_state": "RUNNING", "result_state": "RUNNING"},
        "repair_history": [original_attempt, first_repair, latest_repair],
    }

    assert self.hook.get_latest_repair_id(RUN_ID) == expected_repair_id

@mock.patch("airflow.providers.databricks.hooks.databricks_base.requests")
def test_get_cluster_state(self, mock_requests):
"""
Expand Down

0 comments on commit 574102f

Please sign in to comment.