diff --git a/airflow/providers/databricks/hooks/databricks.py b/airflow/providers/databricks/hooks/databricks.py index 458577f95ffda..d69d0d3295728 100644 --- a/airflow/providers/databricks/hooks/databricks.py +++ b/airflow/providers/databricks/hooks/databricks.py @@ -46,6 +46,7 @@ DELETE_RUN_ENDPOINT = ("POST", "api/2.1/jobs/runs/delete") REPAIR_RUN_ENDPOINT = ("POST", "api/2.1/jobs/runs/repair") OUTPUT_RUNS_JOB_ENDPOINT = ("GET", "api/2.1/jobs/runs/get-output") +CANCEL_ALL_RUNS_ENDPOINT = ("POST", "api/2.1/jobs/runs/cancel-all") INSTALL_LIBS_ENDPOINT = ("POST", "api/2.0/libraries/install") UNINSTALL_LIBS_ENDPOINT = ("POST", "api/2.0/libraries/uninstall") @@ -353,6 +354,15 @@ def cancel_run(self, run_id: int) -> None: json = {"run_id": run_id} self._do_api_call(CANCEL_RUN_ENDPOINT, json) + def cancel_all_runs(self, job_id: int) -> None: + """ + Cancels all active runs of a job. The runs are canceled asynchronously. + + :param job_id: The canonical identifier of the job to cancel all runs of + """ + json = {"job_id": job_id} + self._do_api_call(CANCEL_ALL_RUNS_ENDPOINT, json) + def delete_run(self, run_id: int) -> None: """ Deletes a non-active run. diff --git a/tests/providers/databricks/hooks/test_databricks.py b/tests/providers/databricks/hooks/test_databricks.py index 1d6d862363e5a..ec3d9e28c6ad8 100644 --- a/tests/providers/databricks/hooks/test_databricks.py +++ b/tests/providers/databricks/hooks/test_databricks.py @@ -138,11 +138,18 @@ def get_run_output_endpoint(host): def cancel_run_endpoint(host): """ - Utility function to generate the get run endpoint given the host. + Utility function to generate the cancel run endpoint given the host. """ return f"https://{host}/api/2.1/jobs/runs/cancel" +def cancel_all_runs_endpoint(host): + """ + Utility function to generate the cancel all runs endpoint given the host. + """ + return f"https://{host}/api/2.1/jobs/runs/cancel-all" + + def delete_run_endpoint(host): """ Utility function to generate delete run endpoint given the host. @@ -535,6 +542,21 @@ def test_cancel_run(self, mock_requests): timeout=self.hook.timeout_seconds, ) + @mock.patch("airflow.providers.databricks.hooks.databricks_base.requests") + def test_cancel_all_runs(self, mock_requests): + mock_requests.post.return_value.json.return_value = {} + + self.hook.cancel_all_runs(JOB_ID) + + mock_requests.post.assert_called_once_with( + cancel_all_runs_endpoint(HOST), + json={"job_id": JOB_ID}, + params=None, + auth=HTTPBasicAuth(LOGIN, PASSWORD), + headers=self.hook.user_agent_header, + timeout=self.hook.timeout_seconds, + ) + @mock.patch("airflow.providers.databricks.hooks.databricks_base.requests") def test_delete_run(self, mock_requests): mock_requests.post.return_value.json.return_value = {}