Skip to content

Commit

Permalink
Add install/uninstall api to databricks hook (#12316)
Browse files Browse the repository at this point in the history
- adding install Databricks API to databricks hook(api/2.0/libraries/install)

- adding uninstall Databricks API to databricks hook (2.0/libraries/uninstall)
  • Loading branch information
hnaoto committed Nov 13, 2020
1 parent 75f25bd commit b027223
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 0 deletions.
25 changes: 25 additions & 0 deletions airflow/providers/databricks/hooks/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@
CANCEL_RUN_ENDPOINT = ('POST', 'api/2.0/jobs/runs/cancel')
USER_AGENT_HEADER = {'user-agent': f'airflow-{__version__}'}

INSTALL_LIBS_ENDPOINT = ('POST', 'api/2.0/libraries/install')
UNINSTALL_LIBS_ENDPOINT = ('POST', 'api/2.0/libraries/uninstall')


class RunState:
"""Utility class for the run state concept of Databricks runs."""
Expand Down Expand Up @@ -311,6 +314,28 @@ def terminate_cluster(self, json: dict) -> None:
"""
self._do_api_call(TERMINATE_CLUSTER_ENDPOINT, json)

def install(self, json: dict) -> None:
"""
Install libraries on the cluster.
Utility function to call the ``2.0/libraries/install`` endpoint.
:param json: json dictionary containing cluster_id and an array of library
:type json: dict
"""
self._do_api_call(INSTALL_LIBS_ENDPOINT, json)

def uninstall(self, json: dict) -> None:
"""
Uninstall libraries on the cluster.
Utility function to call the ``2.0/libraries/uninstall`` endpoint.
:param json: json dictionary containing cluster_id and an array of library
:type json: dict
"""
self._do_api_call(UNINSTALL_LIBS_ENDPOINT, json)


def _retryable_error(exception) -> bool:
return (
Expand Down
56 changes: 56 additions & 0 deletions tests/providers/databricks/hooks/test_databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@
NOTEBOOK_PARAMS = {"dry-run": "true", "oldest-time-to-consider": "1457570074236"}
JAR_PARAMS = ["param1", "param2"]
RESULT_STATE = None # type: None
LIBRARIES = [
{"jar": "dbfs:/mnt/libraries/library.jar"},
{"maven": {"coordinates": "org.jsoup:jsoup:1.7.2", "exclusions": ["slf4j:slf4j"]}},
]


def run_now_endpoint(host):
Expand Down Expand Up @@ -106,6 +110,20 @@ def terminate_cluster_endpoint(host):
return f'https://{host}/api/2.0/clusters/delete'


def install_endpoint(host):
"""
Utility function to generate the install endpoint given the host.
"""
return f'https://{host}/api/2.0/libraries/install'


def uninstall_endpoint(host):
"""
Utility function to generate the uninstall endpoint given the host.
"""
return f'https://{host}/api/2.0/libraries/uninstall'


def create_valid_response_mock(content):
response = mock.MagicMock()
response.json.return_value = content
Expand Down Expand Up @@ -424,6 +442,44 @@ def test_terminate_cluster(self, mock_requests):
timeout=self.hook.timeout_seconds,
)

@mock.patch('airflow.providers.databricks.hooks.databricks.requests')
def test_install_libs_on_cluster(self, mock_requests):
mock_requests.codes.ok = 200
mock_requests.post.return_value.json.return_value = {}
status_code_mock = mock.PropertyMock(return_value=200)
type(mock_requests.post.return_value).status_code = status_code_mock

data = {'cluster_id': CLUSTER_ID, 'libraries': LIBRARIES}
self.hook.install(data)

mock_requests.post.assert_called_once_with(
install_endpoint(HOST),
json={'cluster_id': CLUSTER_ID, 'libraries': LIBRARIES},
params=None,
auth=(LOGIN, PASSWORD),
headers=USER_AGENT_HEADER,
timeout=self.hook.timeout_seconds,
)

@mock.patch('airflow.providers.databricks.hooks.databricks.requests')
def test_uninstall_libs_on_cluster(self, mock_requests):
mock_requests.codes.ok = 200
mock_requests.post.return_value.json.return_value = {}
status_code_mock = mock.PropertyMock(return_value=200)
type(mock_requests.post.return_value).status_code = status_code_mock

data = {'cluster_id': CLUSTER_ID, 'libraries': LIBRARIES}
self.hook.uninstall(data)

mock_requests.post.assert_called_once_with(
uninstall_endpoint(HOST),
json={'cluster_id': CLUSTER_ID, 'libraries': LIBRARIES},
params=None,
auth=(LOGIN, PASSWORD),
headers=USER_AGENT_HEADER,
timeout=self.hook.timeout_seconds,
)


class TestDatabricksHookToken(unittest.TestCase):
"""
Expand Down

0 comments on commit b027223

Please sign in to comment.