
Add operator to create jobs in Databricks (#35156)
* Provider Databricks add jobs create operator.

* run black formatter with breeze

* added support for databricks sdk to use the latest set of objects for type hints

* remove without precommit

* added databricks-sdk with precommit

* use the databricks sdk objects

* fixed type hints and adjusted tests

* fixed as dict

* fixed tests with proper testing logic

* added jobs_create to provider.yaml file

* resolved comments on PR

* fixed imports in test_databricks.py

* added correct type hint for reset_job

* change type hint for json arg in DatabricksCreateJobsOperator

* fixed CI errors

* fixed broken tests and imports. also pinned databricks sdk to a specific version ==0.10.0


* Fix CI static checks

* Remove databricks-sdk dependency

This was agreed with @stikkireddy, since the SDK interfaces are changing at the moment. When they become stable, we can re-introduce this dependency

---------

Co-authored-by: Kyle Winkelman <kyle.winkelman@optum.com>
Co-authored-by: Sri Tikkireddy <sri.tikkireddy@databricks.com>
Co-authored-by: stikkireddy <54602805+stikkireddy@users.noreply.github.com>
4 people committed Oct 27, 2023
1 parent da2fdbb commit a8784e3
Showing 7 changed files with 776 additions and 3 deletions.
20 changes: 20 additions & 0 deletions airflow/providers/databricks/hooks/databricks.py
@@ -41,6 +41,8 @@
START_CLUSTER_ENDPOINT = ("POST", "api/2.0/clusters/start")
TERMINATE_CLUSTER_ENDPOINT = ("POST", "api/2.0/clusters/delete")

CREATE_ENDPOINT = ("POST", "api/2.1/jobs/create")
RESET_ENDPOINT = ("POST", "api/2.1/jobs/reset")
RUN_NOW_ENDPOINT = ("POST", "api/2.1/jobs/run-now")
SUBMIT_RUN_ENDPOINT = ("POST", "api/2.1/jobs/runs/submit")
GET_RUN_ENDPOINT = ("GET", "api/2.1/jobs/runs/get")
@@ -194,6 +196,24 @@ def __init__(
) -> None:
super().__init__(databricks_conn_id, timeout_seconds, retry_limit, retry_delay, retry_args, caller)

def create_job(self, json: dict) -> int:
"""
Utility function to call the ``api/2.1/jobs/create`` endpoint.

:param json: The data used in the body of the request to the ``create`` endpoint.
:return: The job_id of the created job as an int.
"""
response = self._do_api_call(CREATE_ENDPOINT, json)
return response["job_id"]

def reset_job(self, job_id: str, json: dict) -> None:
"""
Utility function to call the ``api/2.1/jobs/reset`` endpoint.

:param job_id: The id of the job to reset.
:param json: The data used as ``new_settings`` in the body of the request to the ``reset`` endpoint.
"""
self._do_api_call(RESET_ENDPOINT, {"job_id": job_id, "new_settings": json})

def run_now(self, json: dict) -> int:
"""
Call the ``api/2.1/jobs/run-now`` endpoint.
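Illustrative only (not part of this commit): a minimal sketch of calling the new hook methods directly, assuming a configured ``databricks_default`` connection; the job name, notebook path, and cluster id below are hypothetical placeholders.

    from airflow.providers.databricks.hooks.databricks import DatabricksHook

    hook = DatabricksHook(databricks_conn_id="databricks_default")

    job_settings = {
        "name": "example-job",  # hypothetical job name
        "tasks": [
            {
                "task_key": "ingest",
                "notebook_task": {"notebook_path": "/Shared/ingest"},  # hypothetical path
                "existing_cluster_id": "1234-567890-abcde123",  # hypothetical cluster id
            }
        ],
    }

    # POST api/2.1/jobs/create; returns the new job_id as an int.
    job_id = hook.create_job(job_settings)

    # POST api/2.1/jobs/reset; replaces the settings of the existing job.
    hook.reset_job(str(job_id), job_settings)
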
132 changes: 130 additions & 2 deletions airflow/providers/databricks/operators/databricks.py
@@ -21,6 +21,7 @@
import time
import warnings
from functools import cached_property
from logging import Logger
from typing import TYPE_CHECKING, Any, Sequence

from airflow.configuration import conf
@@ -31,8 +32,6 @@
from airflow.providers.databricks.utils.databricks import normalise_json_content, validate_trigger_event

if TYPE_CHECKING:
from logging import Logger

from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.utils.context import Context

@@ -162,6 +161,135 @@ def get_link(
return XCom.get_value(key=XCOM_RUN_PAGE_URL_KEY, ti_key=ti_key)


class DatabricksCreateJobsOperator(BaseOperator):
"""Creates (or resets) a Databricks job using the API endpoint.
.. seealso::
https://docs.databricks.com/api/workspace/jobs/create
https://docs.databricks.com/api/workspace/jobs/reset
:param json: A JSON object containing API parameters which will be passed
directly to the ``api/2.1/jobs/create`` endpoint. The other named parameters
(e.g. ``name``, ``tags``, ``tasks``) to this operator will
be merged with this json dictionary if they are provided.
If there are conflicts during the merge, the named parameters will
take precedence and override the top level json keys. (templated)
.. seealso::
For more information about templating see :ref:`concepts:jinja-templating`.
:param name: An optional name for the job.
:param tags: A map of tags associated with the job.
:param tasks: A list of task specifications to be executed by this job.
Array of objects (JobTaskSettings).
:param job_clusters: A list of job cluster specifications that can be shared and reused by
tasks of this job. Array of objects (JobCluster).
:param email_notifications: Object (JobEmailNotifications).
:param webhook_notifications: Object (WebhookNotifications).
:param timeout_seconds: An optional timeout applied to each run of this job.
:param schedule: Object (CronSchedule).
:param max_concurrent_runs: An optional maximum allowed number of concurrent runs of the job.
:param git_source: An optional specification for a remote repository containing the notebooks
used by this job's notebook tasks. Object (GitSource).
:param access_control_list: List of permissions to set on the job. Array of object
(AccessControlRequestForUser) or object (AccessControlRequestForGroup) or object
(AccessControlRequestForServicePrincipal).
.. seealso::
This will only be used on create. To reset the ACL of an existing job, consider using the
Databricks UI.
:param databricks_conn_id: Reference to the
:ref:`Databricks connection <howto/connection:databricks>`. (templated)
:param polling_period_seconds: Controls the rate at which we poll for the result of
this run. By default the operator will poll every 30 seconds.
:param databricks_retry_limit: Number of times to retry if the Databricks backend is
unreachable. Its value must be greater than or equal to 1.
:param databricks_retry_delay: Number of seconds to wait between retries (it
might be a floating point number).
:param databricks_retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` class.
"""

# Used in airflow.models.BaseOperator
template_fields: Sequence[str] = ("json", "databricks_conn_id")
# Databricks brand color (blue) under white text
ui_color = "#1CB1C2"
ui_fgcolor = "#fff"

def __init__(
self,
*,
json: Any | None = None,
name: str | None = None,
tags: dict[str, str] | None = None,
tasks: list[dict] | None = None,
job_clusters: list[dict] | None = None,
email_notifications: dict | None = None,
webhook_notifications: dict | None = None,
timeout_seconds: int | None = None,
schedule: dict | None = None,
max_concurrent_runs: int | None = None,
git_source: dict | None = None,
access_control_list: list[dict] | None = None,
databricks_conn_id: str = "databricks_default",
polling_period_seconds: int = 30,
databricks_retry_limit: int = 3,
databricks_retry_delay: int = 1,
databricks_retry_args: dict[Any, Any] | None = None,
**kwargs,
) -> None:
"""Creates a new ``DatabricksCreateJobsOperator``."""
super().__init__(**kwargs)
self.json = json or {}
self.databricks_conn_id = databricks_conn_id
self.polling_period_seconds = polling_period_seconds
self.databricks_retry_limit = databricks_retry_limit
self.databricks_retry_delay = databricks_retry_delay
self.databricks_retry_args = databricks_retry_args
if name is not None:
self.json["name"] = name
if tags is not None:
self.json["tags"] = tags
if tasks is not None:
self.json["tasks"] = tasks
if job_clusters is not None:
self.json["job_clusters"] = job_clusters
if email_notifications is not None:
self.json["email_notifications"] = email_notifications
if webhook_notifications is not None:
self.json["webhook_notifications"] = webhook_notifications
if timeout_seconds is not None:
self.json["timeout_seconds"] = timeout_seconds
if schedule is not None:
self.json["schedule"] = schedule
if max_concurrent_runs is not None:
self.json["max_concurrent_runs"] = max_concurrent_runs
if git_source is not None:
self.json["git_source"] = git_source
if access_control_list is not None:
self.json["access_control_list"] = access_control_list

self.json = normalise_json_content(self.json)

@cached_property
def _hook(self):
return DatabricksHook(
self.databricks_conn_id,
retry_limit=self.databricks_retry_limit,
retry_delay=self.databricks_retry_delay,
retry_args=self.databricks_retry_args,
caller="DatabricksCreateJobsOperator",
)

def execute(self, context: Context) -> int:
if "name" not in self.json:
raise AirflowException("Missing required parameter: name")
job_id = self._hook.find_job_id_by_name(self.json["name"])
if job_id is None:
return self._hook.create_job(self.json)
self._hook.reset_job(str(job_id), self.json)
return job_id


class DatabricksSubmitRunOperator(BaseOperator):
"""
Submits a Spark job run to Databricks using the api/2.1/jobs/runs/submit API endpoint.
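Illustrative only (not part of this commit): a minimal DAG sketch of the create-or-reset behaviour implemented in ``execute`` above; the dag id, job name, notebook path, and cluster id are hypothetical placeholders.

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.databricks.operators.databricks import DatabricksCreateJobsOperator

    with DAG(
        dag_id="example_databricks_jobs_create",
        start_date=datetime(2023, 1, 1),
        schedule=None,
        catchup=False,
    ):
        # On the first run this creates the job; on later runs the operator finds the
        # existing job by name and resets it with the same settings.
        DatabricksCreateJobsOperator(
            task_id="jobs_create",
            name="example-job",
            tasks=[
                {
                    "task_key": "transform",
                    "notebook_task": {"notebook_path": "/Shared/transform"},
                    "existing_cluster_id": "1234-567890-abcde123",
                }
            ],
        )
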
8 changes: 8 additions & 0 deletions airflow/providers/databricks/provider.yaml
@@ -66,6 +66,7 @@ integrations:
- integration-name: Databricks
external-doc-url: https://databricks.com/
how-to-guide:
- /docs/apache-airflow-providers-databricks/operators/jobs_create.rst
- /docs/apache-airflow-providers-databricks/operators/submit_run.rst
- /docs/apache-airflow-providers-databricks/operators/run_now.rst
logo: /integration-logos/databricks/Databricks.png
@@ -123,3 +124,10 @@

extra-links:
- airflow.providers.databricks.operators.databricks.DatabricksJobRunLink

additional-extras:
# pip install apache-airflow-providers-databricks[sdk]
- name: sdk
description: Install Databricks SDK
dependencies:
- databricks-sdk==0.10.0
91 changes: 91 additions & 0 deletions docs/apache-airflow-providers-databricks/operators/jobs_create.rst
@@ -0,0 +1,91 @@
.. Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at
.. http://www.apache.org/licenses/LICENSE-2.0
.. Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
DatabricksCreateJobsOperator
============================

Use the :class:`~airflow.providers.databricks.operators.databricks.DatabricksCreateJobsOperator` to create
(or reset) a Databricks job. The operator looks up an existing job by its ``name`` and resets it if found,
otherwise it creates a new job, so repeated calls with this operator update the existing job rather than
creating new ones. When paired with the DatabricksRunNowOperator, all runs fall under the same
job within the Databricks UI.


Using the Operator
------------------

There are three ways to instantiate this operator. In the first way, you can take the JSON payload that you typically use
to call the ``api/2.1/jobs/create`` endpoint and pass it directly to our ``DatabricksCreateJobsOperator`` through the
``json`` parameter. With this approach you get full control over the underlying payload to the Jobs REST API, including
execution of Databricks jobs with multiple tasks, but it's harder to detect errors because of the lack of type checking.

The second way to accomplish the same thing is to use the named parameters of the ``DatabricksCreateJobsOperator`` directly. Note that there is exactly
one named parameter for each top level parameter in the ``api/2.1/jobs/create`` endpoint.

The third way is to use both the json parameter **AND** the named parameters. They will be merged
together. If there are conflicts during the merge, the named parameters will take precedence and
override the top level ``json`` keys.
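
For illustration, a hedged sketch of this merge (not taken from the example DAG; the task and cluster values are placeholders):

.. code-block:: python

    json = {
        "name": "this-name-is-overridden",
        "timeout_seconds": 1200,
        "tasks": [
            {
                "task_key": "etl",
                "notebook_task": {"notebook_path": "/Shared/etl"},
                "existing_cluster_id": "1234-567890-abcde123",
            }
        ],
    }

    jobs_create = DatabricksCreateJobsOperator(
        task_id="jobs_create",
        json=json,
        # Named parameters win on conflict, so the job is created with this name.
        name="job-created-from-airflow",
    )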

Currently the named parameters that ``DatabricksCreateJobsOperator`` supports are:
- ``name``
- ``tags``
- ``tasks``
- ``job_clusters``
- ``email_notifications``
- ``webhook_notifications``
- ``timeout_seconds``
- ``schedule``
- ``max_concurrent_runs``
- ``git_source``
- ``access_control_list``


Examples
--------

Specifying parameters as JSON
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

An example usage of the DatabricksCreateJobsOperator is as follows:

.. exampleinclude:: /../../tests/system/providers/databricks/example_databricks.py
:language: python
:start-after: [START howto_operator_databricks_jobs_create_json]
:end-before: [END howto_operator_databricks_jobs_create_json]

Using named parameters
^^^^^^^^^^^^^^^^^^^^^^

You can also use named parameters to initialize the operator and run the job.

.. exampleinclude:: /../../tests/system/providers/databricks/example_databricks.py
:language: python
:start-after: [START howto_operator_databricks_jobs_create_named]
:end-before: [END howto_operator_databricks_jobs_create_named]

Pairing with DatabricksRunNowOperator
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

You can use the ``job_id`` returned by the DatabricksCreateJobsOperator (available in the
``return_value`` XCom) as an argument to the DatabricksRunNowOperator to run the job.

.. exampleinclude:: /../../tests/system/providers/databricks/example_databricks.py
:language: python
:start-after: [START howto_operator_databricks_run_now]
:end-before: [END howto_operator_databricks_run_now]
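
For illustration, a hedged inline sketch of this pairing (the task ids and the ``job_spec`` payload are placeholders):

.. code-block:: python

    jobs_create = DatabricksCreateJobsOperator(task_id="jobs_create", json=job_spec)

    # The job_id returned by DatabricksCreateJobsOperator is pushed to the
    # return_value XCom and can be pulled into DatabricksRunNowOperator.
    run_now = DatabricksRunNowOperator(
        task_id="run_now",
        job_id="{{ ti.xcom_pull(task_ids='jobs_create') }}",
    )

    jobs_create >> run_now
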
51 changes: 51 additions & 0 deletions tests/providers/databricks/hooks/test_databricks.py
@@ -107,6 +107,20 @@
}


def create_endpoint(host):
"""
Utility function to generate the create endpoint given the host.
"""
return f"https://{host}/api/2.1/jobs/create"


def reset_endpoint(host):
"""
Utility function to generate the reset endpoint given the host.
"""
return f"https://{host}/api/2.1/jobs/reset"


def run_now_endpoint(host):
"""
Utility function to generate the run now endpoint given the host.
@@ -387,6 +401,43 @@ def test_do_api_call_patch(self, mock_requests):
timeout=self.hook.timeout_seconds,
)

@mock.patch("airflow.providers.databricks.hooks.databricks_base.requests")
def test_create(self, mock_requests):
mock_requests.codes.ok = 200
mock_requests.post.return_value.json.return_value = {"job_id": JOB_ID}
status_code_mock = mock.PropertyMock(return_value=200)
type(mock_requests.post.return_value).status_code = status_code_mock
json = {"name": "test"}
job_id = self.hook.create_job(json)

assert job_id == JOB_ID

mock_requests.post.assert_called_once_with(
create_endpoint(HOST),
json={"name": "test"},
params=None,
auth=HTTPBasicAuth(LOGIN, PASSWORD),
headers=self.hook.user_agent_header,
timeout=self.hook.timeout_seconds,
)

@mock.patch("airflow.providers.databricks.hooks.databricks_base.requests")
def test_reset(self, mock_requests):
mock_requests.codes.ok = 200
status_code_mock = mock.PropertyMock(return_value=200)
type(mock_requests.post.return_value).status_code = status_code_mock
json = {"name": "test"}
self.hook.reset_job(JOB_ID, json)

mock_requests.post.assert_called_once_with(
reset_endpoint(HOST),
json={"job_id": JOB_ID, "new_settings": {"name": "test"}},
params=None,
auth=HTTPBasicAuth(LOGIN, PASSWORD),
headers=self.hook.user_agent_header,
timeout=self.hook.timeout_seconds,
)

@mock.patch("airflow.providers.databricks.hooks.databricks_base.requests")
def test_submit_run(self, mock_requests):
mock_requests.post.return_value.json.return_value = {"run_id": "1"}
