Skip to content

Commit

Permalink
Add DatabricksNotebookOperator (#39178)
Browse files Browse the repository at this point in the history
Co-authored-by: Tatiana Al-Chueyr <tatiana.alchueyr@gmail.com>
Co-authored-by: Wei Lee <weilee.rx@gmail.com>
  • Loading branch information
3 people committed Apr 26, 2024
1 parent bea1b7f commit 7683344
Show file tree
Hide file tree
Showing 5 changed files with 416 additions and 1 deletion.
174 changes: 174 additions & 0 deletions airflow/providers/databricks/operators/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -892,3 +892,177 @@ class DatabricksRunNowDeferrableOperator(DatabricksRunNowOperator):

def __init__(self, *args, **kwargs):
super().__init__(deferrable=True, *args, **kwargs)


class DatabricksNotebookOperator(BaseOperator):
"""
Runs a notebook on Databricks using an Airflow operator.
The DatabricksNotebookOperator allows users to launch and monitor notebook
job runs on Databricks as Airflow tasks.
.. seealso::
For more information on how to use this operator, take a look at the guide:
:ref:`howto/operator:DatabricksNotebookOperator`
:param notebook_path: The path to the notebook in Databricks.
:param source: Optional location type of the notebook. When set to WORKSPACE, the notebook will be retrieved
from the local Databricks workspace. When set to GIT, the notebook will be retrieved from a Git repository
defined in git_source. If the value is empty, the task will use GIT if git_source is defined
and WORKSPACE otherwise. For more information please visit
https://docs.databricks.com/dev-tools/api/latest/jobs.html#operation/JobsCreate
:param notebook_params: A dict of key-value pairs to be passed as optional params to the notebook task.
:param notebook_packages: A list of the Python libraries to be installed on the cluster running the
notebook.
:param new_cluster: Specs for a new cluster on which this task will be run.
:param existing_cluster_id: ID for existing cluster on which to run this task.
:param job_cluster_key: The key for the job cluster.
:param polling_period_seconds: Controls the rate which we poll for the result of this notebook job run.
:param databricks_retry_limit: Amount of times to retry if the Databricks backend is unreachable.
:param databricks_retry_delay: Number of seconds to wait between retries.
:param databricks_retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` class.
:param wait_for_termination: if we should wait for termination of the job run. ``True`` by default.
:param databricks_conn_id: The name of the Airflow connection to use.
"""

template_fields = ("notebook_params",)

def __init__(
self,
notebook_path: str,
source: str,
notebook_params: dict | None = None,
notebook_packages: list[dict[str, Any]] | None = None,
new_cluster: dict[str, Any] | None = None,
existing_cluster_id: str = "",
job_cluster_key: str = "",
polling_period_seconds: int = 5,
databricks_retry_limit: int = 3,
databricks_retry_delay: int = 1,
databricks_retry_args: dict[Any, Any] | None = None,
wait_for_termination: bool = True,
databricks_conn_id: str = "databricks_default",
**kwargs: Any,
):
self.notebook_path = notebook_path
self.source = source
self.notebook_params = notebook_params or {}
self.notebook_packages = notebook_packages or []
self.new_cluster = new_cluster or {}
self.existing_cluster_id = existing_cluster_id
self.job_cluster_key = job_cluster_key
self.polling_period_seconds = polling_period_seconds
self.databricks_retry_limit = databricks_retry_limit
self.databricks_retry_delay = databricks_retry_delay
self.databricks_retry_args = databricks_retry_args
self.wait_for_termination = wait_for_termination
self.databricks_conn_id = databricks_conn_id
self.databricks_run_id: int | None = None
super().__init__(**kwargs)

@cached_property
def _hook(self) -> DatabricksHook:
return self._get_hook(caller="DatabricksNotebookOperator")

def _get_hook(self, caller: str) -> DatabricksHook:
return DatabricksHook(
self.databricks_conn_id,
retry_limit=self.databricks_retry_limit,
retry_delay=self.databricks_retry_delay,
retry_args=self.databricks_retry_args,
caller=caller,
)

def _get_task_timeout_seconds(self) -> int:
"""
Get the timeout seconds value for the Databricks job based on the execution timeout value provided for the Airflow task.
By default, tasks in Airflow have an execution_timeout set to None. In Airflow, when
execution_timeout is not defined, the task continues to run indefinitely. Therefore,
to mirror this behavior in the Databricks Jobs API, we set the timeout to 0, indicating
that the job should run indefinitely. This aligns with the default behavior of Databricks jobs,
where a timeout seconds value of 0 signifies an indefinite run duration.
More details can be found in the Databricks documentation:
See https://docs.databricks.com/api/workspace/jobs/submit#timeout_seconds
"""
if self.execution_timeout is None:
return 0
execution_timeout_seconds = int(self.execution_timeout.total_seconds())
if execution_timeout_seconds == 0:
raise ValueError(
"If you've set an `execution_timeout` for the task, ensure it's not `0`. Set it instead to "
"`None` if you desire the task to run indefinitely."
)
return execution_timeout_seconds

def _get_task_base_json(self) -> dict[str, Any]:
"""Get task base json to be used for task submissions."""
return {
"timeout_seconds": self._get_task_timeout_seconds(),
"email_notifications": {},
"notebook_task": {
"notebook_path": self.notebook_path,
"source": self.source,
"base_parameters": self.notebook_params,
},
"libraries": self.notebook_packages,
}

def _get_databricks_task_id(self, task_id: str) -> str:
"""Get the databricks task ID using dag_id and task_id. Removes illegal characters."""
return f"{self.dag_id}__{task_id.replace('.', '__')}"

def _get_run_json(self) -> dict[str, Any]:
"""Get run json to be used for task submissions."""
run_json = {
"run_name": self._get_databricks_task_id(self.task_id),
**self._get_task_base_json(),
}
if self.new_cluster and self.existing_cluster_id:
raise ValueError("Both new_cluster and existing_cluster_id are set. Only one should be set.")
if self.new_cluster:
run_json["new_cluster"] = self.new_cluster
elif self.existing_cluster_id:
run_json["existing_cluster_id"] = self.existing_cluster_id
else:
raise ValueError("Must specify either existing_cluster_id or new_cluster.")
return run_json

def launch_notebook_job(self) -> int:
run_json = self._get_run_json()
self.databricks_run_id = self._hook.submit_run(run_json)
url = self._hook.get_run_page_url(self.databricks_run_id)
self.log.info("Check the job run in Databricks: %s", url)
return self.databricks_run_id

def monitor_databricks_job(self) -> None:
if self.databricks_run_id is None:
raise ValueError("Databricks job not yet launched. Please run launch_notebook_job first.")
run = self._hook.get_run(self.databricks_run_id)
run_state = RunState(**run["state"])
self.log.info("Current state of the job: %s", run_state.life_cycle_state)
while not run_state.is_terminal:
time.sleep(self.polling_period_seconds)
run = self._hook.get_run(self.databricks_run_id)
run_state = RunState(**run["state"])
self.log.info(
"task %s %s", self._get_databricks_task_id(self.task_id), run_state.life_cycle_state
)
self.log.info("Current state of the job: %s", run_state.life_cycle_state)
if run_state.life_cycle_state != "TERMINATED":
raise AirflowException(
f"Databricks job failed with state {run_state.life_cycle_state}. "
f"Message: {run_state.state_message}"
)
if not run_state.is_successful:
raise AirflowException(
"Task failed. Final state %s. Reason: %s",
run_state.result_state,
run_state.state_message,
)
self.log.info("Task succeeded. Final state %s.", run_state.result_state)

def execute(self, context: Context) -> None:
self.launch_notebook_job()
if self.wait_for_termination:
self.monitor_databricks_job()
1 change: 1 addition & 0 deletions airflow/providers/databricks/provider.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ integrations:
external-doc-url: https://databricks.com/
how-to-guide:
- /docs/apache-airflow-providers-databricks/operators/jobs_create.rst
- /docs/apache-airflow-providers-databricks/operators/notebook.rst
- /docs/apache-airflow-providers-databricks/operators/submit_run.rst
- /docs/apache-airflow-providers-databricks/operators/run_now.rst
logo: /integration-logos/databricks/Databricks.png
Expand Down
44 changes: 44 additions & 0 deletions docs/apache-airflow-providers-databricks/operators/notebook.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
.. Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at
.. http://www.apache.org/licenses/LICENSE-2.0
.. Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
.. _howto/operator:DatabricksNotebookOperator:


DatabricksNotebookOperator
==========================

Use the :class:`~airflow.providers.databricks.operators.databricks.DatabricksNotebookOperator` to launch and monitor
notebook job runs on Databricks as Airflow tasks.



Examples
--------

Running a notebook in Databricks on a new cluster
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. exampleinclude:: /../../tests/system/providers/databricks/example_databricks.py
:language: python
:start-after: [START howto_operator_databricks_notebook_new_cluster]
:end-before: [END howto_operator_databricks_notebook_new_cluster]

Running a notebook in Databricks on an existing cluster
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. exampleinclude:: /../../tests/system/providers/databricks/example_databricks.py
:language: python
:start-after: [START howto_operator_databricks_notebook_existing_cluster]
:end-before: [END howto_operator_databricks_notebook_existing_cluster]
144 changes: 143 additions & 1 deletion tests/providers/databricks/operators/test_databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# under the License.
from __future__ import annotations

from datetime import datetime
from datetime import datetime, timedelta
from unittest import mock
from unittest.mock import MagicMock

Expand All @@ -28,6 +28,7 @@
from airflow.providers.databricks.hooks.databricks import RunState
from airflow.providers.databricks.operators.databricks import (
DatabricksCreateJobsOperator,
DatabricksNotebookOperator,
DatabricksRunNowDeferrableOperator,
DatabricksRunNowOperator,
DatabricksSubmitRunDeferrableOperator,
Expand Down Expand Up @@ -1754,3 +1755,144 @@ def test_databricks_run_now_deferrable_operator_success_before_defer(self, mock_
db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
assert op.run_id == RUN_ID
assert not mock_defer.called


class TestDatabricksNotebookOperator:
def test_execute_with_wait_for_termination(self):
operator = DatabricksNotebookOperator(
task_id="test_task",
notebook_path="test_path",
source="test_source",
databricks_conn_id="test_conn_id",
)
operator.launch_notebook_job = MagicMock(return_value=12345)
operator.monitor_databricks_job = MagicMock()

operator.execute({})

assert operator.wait_for_termination is True
operator.launch_notebook_job.assert_called_once()
operator.monitor_databricks_job.assert_called_once()

def test_execute_without_wait_for_termination(self):
operator = DatabricksNotebookOperator(
task_id="test_task",
notebook_path="test_path",
source="test_source",
databricks_conn_id="test_conn_id",
wait_for_termination=False,
)
operator.launch_notebook_job = MagicMock(return_value=12345)
operator.monitor_databricks_job = MagicMock()

operator.execute({})

assert operator.wait_for_termination is False
operator.launch_notebook_job.assert_called_once()
operator.monitor_databricks_job.assert_not_called()

@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
def test_monitor_databricks_job_successful_raises_no_exception(self, mock_databricks_hook):
mock_databricks_hook.return_value.get_run.return_value = {
"state": {"life_cycle_state": "TERMINATED", "result_state": "SUCCESS"}
}

operator = DatabricksNotebookOperator(
task_id="test_task",
notebook_path="test_path",
source="test_source",
databricks_conn_id="test_conn_id",
)

operator.databricks_run_id = 12345
operator.monitor_databricks_job()

@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
def test_monitor_databricks_job_failed(self, mock_databricks_hook):
mock_databricks_hook.return_value.get_run.return_value = {
"state": {"life_cycle_state": "TERMINATED", "result_state": "FAILED", "state_message": "FAILURE"}
}

operator = DatabricksNotebookOperator(
task_id="test_task",
notebook_path="test_path",
source="test_source",
databricks_conn_id="test_conn_id",
)

operator.databricks_run_id = 12345

exception_message = "'Task failed. Final state %s. Reason: %s', 'FAILED', 'FAILURE'"
with pytest.raises(AirflowException) as exc_info:
operator.monitor_databricks_job()
assert exception_message in str(exc_info.value)

@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
def test_launch_notebook_job(self, mock_databricks_hook):
operator = DatabricksNotebookOperator(
task_id="test_task",
notebook_path="test_path",
source="test_source",
databricks_conn_id="test_conn_id",
existing_cluster_id="test_cluster_id",
)
operator._hook.submit_run.return_value = 12345

run_id = operator.launch_notebook_job()

assert run_id == 12345

def test_both_new_and_existing_cluster_set(self):
operator = DatabricksNotebookOperator(
task_id="test_task",
notebook_path="test_path",
source="test_source",
new_cluster={"new_cluster_config_key": "new_cluster_config_value"},
existing_cluster_id="existing_cluster_id",
databricks_conn_id="test_conn_id",
)
with pytest.raises(ValueError) as exc_info:
operator._get_run_json()
exception_message = "Both new_cluster and existing_cluster_id are set. Only one should be set."
assert str(exc_info.value) == exception_message

def test_both_new_and_existing_cluster_unset(self):
operator = DatabricksNotebookOperator(
task_id="test_task",
notebook_path="test_path",
source="test_source",
databricks_conn_id="test_conn_id",
)
with pytest.raises(ValueError) as exc_info:
operator._get_run_json()
exception_message = "Must specify either existing_cluster_id or new_cluster."
assert str(exc_info.value) == exception_message

def test_job_runs_forever_by_default(self):
operator = DatabricksNotebookOperator(
task_id="test_task",
notebook_path="test_path",
source="test_source",
databricks_conn_id="test_conn_id",
existing_cluster_id="existing_cluster_id",
)
run_json = operator._get_run_json()
assert operator.execution_timeout is None
assert run_json["timeout_seconds"] == 0

def test_zero_execution_timeout_raises_error(self):
operator = DatabricksNotebookOperator(
task_id="test_task",
notebook_path="test_path",
source="test_source",
databricks_conn_id="test_conn_id",
existing_cluster_id="existing_cluster_id",
execution_timeout=timedelta(seconds=0),
)
with pytest.raises(ValueError) as exc_info:
operator._get_run_json()
exception_message = (
"If you've set an `execution_timeout` for the task, ensure it's not `0`. "
"Set it instead to `None` if you desire the task to run indefinitely."
)
assert str(exc_info.value) == exception_message

0 comments on commit 7683344

Please sign in to comment.