fix typos in DatabricksSubmitRunOperator (#36248)
* fix typos in DatabricksSubmitRunOperator

* databricks find_pipeline_id_by_name tests and fixes
adam133 committed Dec 21, 2023
1 parent e2393ee commit 322aa64
Showing 4 changed files with 143 additions and 6 deletions.
10 changes: 5 additions & 5 deletions airflow/providers/databricks/hooks/databricks.py
@@ -55,7 +55,7 @@
 UNINSTALL_LIBS_ENDPOINT = ("POST", "api/2.0/libraries/uninstall")
 
 LIST_JOBS_ENDPOINT = ("GET", "api/2.1/jobs/list")
-LIST_PIPELINES_ENDPOINT = ("GET", "/api/2.0/pipelines")
+LIST_PIPELINES_ENDPOINT = ("GET", "api/2.0/pipelines")
 
 WORKSPACE_GET_STATUS_ENDPOINT = ("GET", "api/2.0/workspace/get-status")
@@ -322,8 +322,8 @@ def list_pipelines(
         payload["filter"] = filter
 
         while has_more:
-            if next_token:
-                payload["page_token"] = next_token
+            if next_token is not None:
+                payload = {**payload, "page_token": next_token}
             response = self._do_api_call(LIST_PIPELINES_ENDPOINT, payload)
             pipelines = response.get("statuses", [])
             all_pipelines += pipelines
@@ -345,11 +345,11 @@ def find_pipeline_id_by_name(self, pipeline_name: str) -> str | None:
 
         if len(matching_pipelines) > 1:
             raise AirflowException(
-                f"There are more than one job with name {pipeline_name}. "
+                f"There are more than one pipelines with name {pipeline_name}. "
                 "Please delete duplicated pipelines first"
             )
 
-        if not pipeline_name:
+        if not pipeline_name or len(matching_pipelines) == 0:
             return None
         else:
             return matching_pipelines[0]["pipeline_id"]
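For context, a minimal usage sketch of the corrected hook method — not part of this commit; the connection id and pipeline name below are placeholders:

```python
from airflow.providers.databricks.hooks.databricks import DatabricksHook

# "databricks_default" is a placeholder connection id, not taken from this commit.
hook = DatabricksHook(databricks_conn_id="databricks_default")

# Returns the matching pipeline_id, or None when the name is empty or nothing
# matches; more than one match raises AirflowException (see the hunk above).
pipeline_id = hook.find_pipeline_id_by_name("my-dlt-pipeline")
```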
2 changes: 1 addition & 1 deletion airflow/providers/databricks/operators/databricks.py
@@ -521,7 +521,7 @@ def execute(self, context: Context):
         ):
             # If pipeline_id is not provided, we need to fetch it from the pipeline_name
             pipeline_name = self.json["pipeline_task"]["pipeline_name"]
-            self.json["pipeline_task"]["pipeline_id"] = self._hook.get_pipeline_id(pipeline_name)
+            self.json["pipeline_task"]["pipeline_id"] = self._hook.find_pipeline_id_by_name(pipeline_name)
             del self.json["pipeline_task"]["pipeline_name"]
         json_normalised = normalise_json_content(self.json)
         self.run_id = self._hook.submit_run(json_normalised)
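For reference, a hedged sketch of how a DAG task would exercise this code path — the task id and pipeline name are illustrative, not from the commit:

```python
from airflow.providers.databricks.operators.databricks import DatabricksSubmitRunOperator

# At execute time the operator resolves "my-dlt-pipeline" (a placeholder name)
# via hook.find_pipeline_id_by_name and swaps pipeline_name for the resolved
# pipeline_id before calling submit_run, as the hunk above shows.
submit_pipeline_run = DatabricksSubmitRunOperator(
    task_id="submit_pipeline_run",
    json={"pipeline_task": {"pipeline_name": "my-dlt-pipeline"}},
)
```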
108 changes: 108 additions & 0 deletions tests/providers/databricks/hooks/test_databricks.py
@@ -59,6 +59,8 @@
 RUN_ID = 1
 JOB_ID = 42
 JOB_NAME = "job-name"
+PIPELINE_NAME = "some pipeline name"
+PIPELINE_ID = "its-a-pipeline-id"
 DEFAULT_RETRY_NUMBER = 3
 DEFAULT_RETRY_ARGS = dict(
     wait=tenacity.wait_none(),
@@ -100,6 +102,19 @@
     ],
     "has_more": False,
 }
+LIST_PIPELINES_RESPONSE = {
+    "statuses": [
+        {
+            "pipeline_id": PIPELINE_ID,
+            "state": "DEPLOYING",
+            "cluster_id": "string",
+            "name": PIPELINE_NAME,
+            "latest_updates": [{"update_id": "string", "state": "QUEUED", "creation_time": "string"}],
+            "creator_user_name": "string",
+            "run_as_user_name": "string",
+        }
+    ]
+}
 LIST_SPARK_VERSIONS_RESPONSE = {
     "versions": [
         {"key": "8.2.x-scala2.12", "name": "8.2 (includes Apache Spark 3.1.1, Scala 2.12)"},
@@ -226,6 +241,13 @@ def list_jobs_endpoint(host):
     return f"https://{host}/api/2.1/jobs/list"
 
 
+def list_pipelines_endpoint(host):
+    """
+    Utility function to generate the list pipelines endpoint given the host
+    """
+    return f"https://{host}/api/2.0/pipelines"
+
+
 def list_spark_versions_endpoint(host):
     """Utility function to generate the list spark versions endpoint given the host"""
     return f"https://{host}/api/2.0/clusters/spark-versions"
@@ -915,6 +937,92 @@ def test_get_job_id_by_name_raise_exception_with_duplicates(self, mock_requests)
             timeout=self.hook.timeout_seconds,
         )
 
+    @mock.patch("airflow.providers.databricks.hooks.databricks_base.requests")
+    def test_get_pipeline_id_by_name_success(self, mock_requests):
+        mock_requests.codes.ok = 200
+        mock_requests.get.return_value.json.return_value = LIST_PIPELINES_RESPONSE
+
+        pipeline_id = self.hook.find_pipeline_id_by_name(PIPELINE_NAME)
+
+        mock_requests.get.assert_called_once_with(
+            list_pipelines_endpoint(HOST),
+            json=None,
+            params={"filter": f"name LIKE '{PIPELINE_NAME}'", "max_results": 25},
+            auth=HTTPBasicAuth(LOGIN, PASSWORD),
+            headers=self.hook.user_agent_header,
+            timeout=self.hook.timeout_seconds,
+        )
+
+        assert pipeline_id == PIPELINE_ID
+
+    @mock.patch("airflow.providers.databricks.hooks.databricks_base.requests")
+    def test_list_pipelines_success_multiple_pages(self, mock_requests):
+        mock_requests.codes.ok = 200
+        mock_requests.get.side_effect = [
+            create_successful_response_mock({**LIST_PIPELINES_RESPONSE, "next_page_token": "PAGETOKEN"}),
+            create_successful_response_mock(LIST_PIPELINES_RESPONSE),
+        ]
+
+        pipelines = self.hook.list_pipelines(pipeline_name=PIPELINE_NAME)
+
+        assert mock_requests.get.call_count == 2
+
+        first_call_args = mock_requests.method_calls[0]
+        assert first_call_args[1][0] == list_pipelines_endpoint(HOST)
+        assert first_call_args[2]["params"] == {"filter": f"name LIKE '{PIPELINE_NAME}'", "max_results": 25}
+
+        second_call_args = mock_requests.method_calls[1]
+        assert second_call_args[1][0] == list_pipelines_endpoint(HOST)
+        assert second_call_args[2]["params"] == {
+            "filter": f"name LIKE '{PIPELINE_NAME}'",
+            "max_results": 25,
+            "page_token": "PAGETOKEN",
+        }
+
+        assert len(pipelines) == 2
+        assert pipelines == LIST_PIPELINES_RESPONSE["statuses"] * 2
+
+    @mock.patch("airflow.providers.databricks.hooks.databricks_base.requests")
+    def test_get_pipeline_id_by_name_not_found(self, mock_requests):
+        empty_response = {"statuses": []}
+        mock_requests.codes.ok = 200
+        mock_requests.get.return_value.json.return_value = empty_response
+
+        ne_pipeline_name = "Non existing pipeline"
+        pipeline_id = self.hook.find_pipeline_id_by_name(ne_pipeline_name)
+
+        mock_requests.get.assert_called_once_with(
+            list_pipelines_endpoint(HOST),
+            json=None,
+            params={"filter": f"name LIKE '{ne_pipeline_name}'", "max_results": 25},
+            auth=HTTPBasicAuth(LOGIN, PASSWORD),
+            headers=self.hook.user_agent_header,
+            timeout=self.hook.timeout_seconds,
+        )
+
+        assert pipeline_id is None
+
+    @mock.patch("airflow.providers.databricks.hooks.databricks_base.requests")
+    def test_list_pipelines_raise_exception_with_duplicates(self, mock_requests):
+        mock_requests.codes.ok = 200
+        mock_requests.get.return_value.json.return_value = {
+            **LIST_PIPELINES_RESPONSE,
+            "statuses": LIST_PIPELINES_RESPONSE["statuses"] * 2,
+        }
+
+        exception_message = f"There are more than one pipelines with name {PIPELINE_NAME}."
+        with pytest.raises(AirflowException, match=exception_message):
+            self.hook.find_pipeline_id_by_name(pipeline_name=PIPELINE_NAME)
+
+        mock_requests.get.assert_called_once_with(
+            list_pipelines_endpoint(HOST),
+            json=None,
+            params={"filter": f"name LIKE '{PIPELINE_NAME}'", "max_results": 25},
+            auth=HTTPBasicAuth(LOGIN, PASSWORD),
+            headers=self.hook.user_agent_header,
+            timeout=self.hook.timeout_seconds,
+        )
+
     @mock.patch("airflow.providers.databricks.hooks.databricks_base.requests")
     def test_connection_success(self, mock_requests):
         mock_requests.codes.ok = 200
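The new tests call a `create_successful_response_mock` helper that this hunk does not show; a plausible stand-in, inferred from how it is consumed (an assumption, not the suite's actual helper), would be:

```python
from unittest import mock

def create_successful_response_mock(content):
    # A response double whose .json() returns `content` with a 200 status,
    # matching how the paginated list_pipelines test reads each page.
    response = mock.MagicMock()
    response.json.return_value = content
    response.status_code = 200
    return response
```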
29 changes: 29 additions & 0 deletions tests/providers/databricks/operators/test_databricks.py
@@ -758,6 +758,35 @@ def test_exec_success(self, db_mock_class):
         db_mock.get_run.assert_called_once_with(RUN_ID)
         assert RUN_ID == op.run_id
 
+    @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
+    def test_exec_pipeline_name(self, db_mock_class):
+        """
+        Test the execute function when provided a pipeline name.
+        """
+        run = {"pipeline_task": {"pipeline_name": "This is a test pipeline"}}
+        op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=run)
+        db_mock = db_mock_class.return_value
+        db_mock.find_pipeline_id_by_name.return_value = PIPELINE_ID_TASK["pipeline_id"]
+        db_mock.submit_run.return_value = 1
+        db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS")
+
+        op.execute(None)
+
+        expected = utils.normalise_json_content({"pipeline_task": PIPELINE_ID_TASK, "run_name": TASK_ID})
+        db_mock_class.assert_called_once_with(
+            DEFAULT_CONN_ID,
+            retry_limit=op.databricks_retry_limit,
+            retry_delay=op.databricks_retry_delay,
+            retry_args=None,
+            caller="DatabricksSubmitRunOperator",
+        )
+        db_mock.find_pipeline_id_by_name.assert_called_once_with("This is a test pipeline")
+
+        db_mock.submit_run.assert_called_once_with(expected)
+        db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
+        db_mock.get_run.assert_called_once_with(RUN_ID)
+        assert RUN_ID == op.run_id
+
     @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
     def test_exec_failure(self, db_mock_class):
         """
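The operator test references a `PIPELINE_ID_TASK` constant defined elsewhere in the file; its assumed shape, inferred from the expected `submit_run` payload (illustrative only, not from this diff), is:

```python
# Assumed shape only: the test reads PIPELINE_ID_TASK["pipeline_id"] and expects
# {"pipeline_task": PIPELINE_ID_TASK} in the submitted JSON.
PIPELINE_ID_TASK = {"pipeline_id": "1234abcd"}
```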
