Skip to content

Commit

Permalink
Update list_jobs function in DatabricksHook to token-based pagination (#33472)
Browse files Browse the repository at this point in the history
Loading branch information
oleksiidav committed Sep 12, 2023
1 parent 401e7bd commit dfec053
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 10 deletions.
30 changes: 27 additions & 3 deletions airflow/providers/databricks/hooks/databricks.py
Expand Up @@ -28,11 +28,12 @@
from __future__ import annotations

import json
import warnings
from typing import Any

from requests import exceptions as requests_exceptions

from airflow.exceptions import AirflowException
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.providers.databricks.hooks.databricks_base import BaseDatabricksHook

RESTART_CLUSTER_ENDPOINT = ("POST", "api/2.0/clusters/restart")
Expand Down Expand Up @@ -163,7 +164,12 @@ def submit_run(self, json: dict) -> int:
return response["run_id"]

def list_jobs(
self, limit: int = 25, offset: int = 0, expand_tasks: bool = False, job_name: str | None = None
self,
limit: int = 25,
offset: int | None = None,
expand_tasks: bool = False,
job_name: str | None = None,
page_token: str | None = None,
) -> list[dict[str, Any]]:
"""
Lists the jobs in the Databricks Job Service.
Expand All @@ -172,17 +178,34 @@ def list_jobs(
:param offset: The offset of the first job to return, relative to the most recently created job.
:param expand_tasks: Whether to include task and cluster details in the response.
:param job_name: Optional name of a job to search.
:param page_token: The optional page token pointing at the first job to return.
:return: A list of jobs.
"""
has_more = True
all_jobs = []
use_token_pagination = (page_token is not None) or (offset is None)
if offset is not None:
warnings.warn(
"""You are using the deprecated offset parameter in list_jobs.
It will be hard-limited at the maximum value of 1000 by Databricks API after Oct 9, 2023.
Please paginate using page_token instead.""",
AirflowProviderDeprecationWarning,
stacklevel=2,
)
if page_token is None:
page_token = ""
if offset is None:
offset = 0

while has_more:
payload: dict[str, Any] = {
"limit": limit,
"expand_tasks": expand_tasks,
"offset": offset,
}
if use_token_pagination:
payload["page_token"] = page_token
else: # offset pagination
payload["offset"] = offset
if job_name:
payload["name"] = job_name
response = self._do_api_call(LIST_JOBS_ENDPOINT, payload)
Expand All @@ -193,6 +216,7 @@ def list_jobs(
all_jobs += jobs
has_more = response.get("has_more", False)
if has_more:
page_token = response.get("next_page_token", "")
offset += len(jobs)

return all_jobs
Expand Down
20 changes: 13 additions & 7 deletions tests/providers/databricks/hooks/test_databricks.py
Expand Up @@ -737,7 +737,7 @@ def test_list_jobs_success_single_page(self, mock_requests):
mock_requests.get.assert_called_once_with(
list_jobs_endpoint(HOST),
json=None,
params={"limit": 25, "offset": 0, "expand_tasks": False},
params={"limit": 25, "page_token": "", "expand_tasks": False},
auth=HTTPBasicAuth(LOGIN, PASSWORD),
headers=self.hook.user_agent_header,
timeout=self.hook.timeout_seconds,
Expand All @@ -749,7 +749,9 @@ def test_list_jobs_success_single_page(self, mock_requests):
def test_list_jobs_success_multiple_pages(self, mock_requests):
mock_requests.codes.ok = 200
mock_requests.get.side_effect = [
create_successful_response_mock({**LIST_JOBS_RESPONSE, "has_more": True}),
create_successful_response_mock(
{**LIST_JOBS_RESPONSE, "has_more": True, "next_page_token": "PAGETOKEN"}
),
create_successful_response_mock(LIST_JOBS_RESPONSE),
]

Expand All @@ -759,11 +761,15 @@ def test_list_jobs_success_multiple_pages(self, mock_requests):

first_call_args = mock_requests.method_calls[0]
assert first_call_args[1][0] == list_jobs_endpoint(HOST)
assert first_call_args[2]["params"] == {"limit": 25, "offset": 0, "expand_tasks": False}
assert first_call_args[2]["params"] == {"limit": 25, "page_token": "", "expand_tasks": False}

second_call_args = mock_requests.method_calls[1]
assert second_call_args[1][0] == list_jobs_endpoint(HOST)
assert second_call_args[2]["params"] == {"limit": 25, "offset": 1, "expand_tasks": False}
assert second_call_args[2]["params"] == {
"limit": 25,
"page_token": "PAGETOKEN",
"expand_tasks": False,
}

assert len(jobs) == 2
assert jobs == LIST_JOBS_RESPONSE["jobs"] * 2
Expand All @@ -778,7 +784,7 @@ def test_get_job_id_by_name_success(self, mock_requests):
mock_requests.get.assert_called_once_with(
list_jobs_endpoint(HOST),
json=None,
params={"limit": 25, "offset": 0, "expand_tasks": False, "name": JOB_NAME},
params={"limit": 25, "page_token": "", "expand_tasks": False, "name": JOB_NAME},
auth=HTTPBasicAuth(LOGIN, PASSWORD),
headers=self.hook.user_agent_header,
timeout=self.hook.timeout_seconds,
Expand All @@ -797,7 +803,7 @@ def test_get_job_id_by_name_not_found(self, mock_requests):
mock_requests.get.assert_called_once_with(
list_jobs_endpoint(HOST),
json=None,
params={"limit": 25, "offset": 0, "expand_tasks": False, "name": job_name},
params={"limit": 25, "page_token": "", "expand_tasks": False, "name": job_name},
auth=HTTPBasicAuth(LOGIN, PASSWORD),
headers=self.hook.user_agent_header,
timeout=self.hook.timeout_seconds,
Expand All @@ -820,7 +826,7 @@ def test_get_job_id_by_name_raise_exception_with_duplicates(self, mock_requests)
mock_requests.get.assert_called_once_with(
list_jobs_endpoint(HOST),
json=None,
params={"limit": 25, "offset": 0, "expand_tasks": False, "name": JOB_NAME},
params={"limit": 25, "page_token": "", "expand_tasks": False, "name": JOB_NAME},
auth=HTTPBasicAuth(LOGIN, PASSWORD),
headers=self.hook.user_agent_header,
timeout=self.hook.timeout_seconds,
Expand Down

0 comments on commit dfec053

Please sign in to comment.