Refactor DatabricksHook (#19835)
eskarimov committed Dec 5, 2021
1 parent af28b41 commit 728e94a
Showing 2 changed files with 128 additions and 65 deletions.
137 changes: 73 additions & 64 deletions airflow/providers/databricks/hooks/databricks.py
@@ -48,6 +48,8 @@

USER_AGENT_HEADER = {'user-agent': f'airflow-{__version__}'}

+RUN_LIFE_CYCLE_STATES = ['PENDING', 'RUNNING', 'TERMINATING', 'TERMINATED', 'SKIPPED', 'INTERNAL_ERROR']
+
# https://docs.microsoft.com/en-us/azure/databricks/dev-tools/api/latest/aad/service-prin-aad-token#--get-an-azure-active-directory-access-token
# https://docs.microsoft.com/en-us/graph/deployments#app-registration-and-token-service-root-endpoints
AZURE_DEFAULT_AD_ENDPOINT = "https://login.microsoftonline.com"
@@ -64,7 +66,9 @@
class RunState:
"""Utility class for the run state concept of Databricks runs."""

-def __init__(self, life_cycle_state: str, result_state: str, state_message: str) -> None:
+def __init__(
+self, life_cycle_state: str, result_state: str = '', state_message: str = '', *args, **kwargs
+) -> None:
self.life_cycle_state = life_cycle_state
self.result_state = result_state
self.state_message = state_message
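
With result_state and state_message now defaulting to empty strings, and any extra keys absorbed by *args/**kwargs, a raw state payload from the Jobs API can be splatted straight into the constructor; this is exactly what get_run_state does below with RunState(**state). A minimal sketch (the user_cancelled_or_timedout key is an illustrative extra field, not a requirement):

state = {
    'life_cycle_state': 'TERMINATED',
    'result_state': 'SUCCESS',
    'state_message': '',
    'user_cancelled_or_timedout': False,  # unknown keys land in **kwargs and are ignored
}
run_state = RunState(**state)
assert run_state.life_cycle_state == 'TERMINATED'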
@@ -131,7 +135,11 @@ def __init__(
) -> None:
super().__init__()
self.databricks_conn_id = databricks_conn_id
-self.databricks_conn = None
+self.databricks_conn = self.get_connection(databricks_conn_id)
+if 'host' in self.databricks_conn.extra_dejson:
+self.host = self._parse_host(self.databricks_conn.extra_dejson['host'])
+else:
+self.host = self._parse_host(self.databricks_conn.host)
self.timeout_seconds = timeout_seconds
if retry_limit < 1:
raise ValueError('Retry limit must be greater than or equal to 1')
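
Note the host resolution now done eagerly in __init__: an explicit 'host' key in the connection extras wins over the Connection's own host field, and both paths go through _parse_host. A hedged standalone sketch of that precedence, with placeholder values:

extra_dejson = {'host': 'adb-1234567890123456.7.azuredatabricks.net'}  # placeholder extras override
conn_host = 'xx.cloud.databricks.com'  # placeholder fallback from the Connection record
host = extra_dejson.get('host', conn_host)
assert host == 'adb-1234567890123456.7.azuredatabricks.net'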
@@ -173,13 +181,11 @@ def _get_aad_token(self, resource: str) -> str:
:param resource: resource to issue token to
:return: AAD token, or raise an exception
"""
-if resource in self.aad_tokens:
-d = self.aad_tokens[resource]
-now = int(time.time())
-if d['expires_on'] > (now + TOKEN_REFRESH_LEAD_TIME): # it expires in more than 2 minutes
-return d['token']
-self.log.info("Existing AAD token is expired, or going to expire soon. Refreshing...")
+aad_token = self.aad_tokens.get(resource)
+if aad_token and self._is_aad_token_valid(aad_token):
+return aad_token['token']

+self.log.info('Existing AAD token is expired, or going to expire soon. Refreshing...')
attempt_num = 1
while True:
try:
@@ -235,21 +241,53 @@ def _get_aad_token(self, resource: str) -> str:
attempt_num += 1
sleep(self.retry_delay)

-def _fill_aad_tokens(self, headers: dict) -> str:
+def _get_aad_headers(self) -> dict:
"""
-Fills headers if necessary (SPN is outside of the workspace) and generates AAD token
-:param headers: dictionary with headers to fill-in
-:return: AAD token
+Fills AAD headers if necessary (SPN is outside of the workspace)
+:return: dictionary with filled AAD headers
"""
-# SP is outside of the workspace
+headers = {}
if 'azure_resource_id' in self.databricks_conn.extra_dejson:
mgmt_token = self._get_aad_token(AZURE_MANAGEMENT_ENDPOINT)
headers['X-Databricks-Azure-Workspace-Resource-Id'] = self.databricks_conn.extra_dejson[
'azure_resource_id'
]
headers['X-Databricks-Azure-SP-Management-Token'] = mgmt_token
+return headers

-return self._get_aad_token(DEFAULT_DATABRICKS_SCOPE)
+@staticmethod
+def _is_aad_token_valid(aad_token: dict) -> bool:
+"""
+Utility function to check AAD token hasn't expired yet
+:param aad_token: dict with properties of AAD token
+:type aad_token: dict
+:return: true if token is valid, false otherwise
+:rtype: bool
+"""
+now = int(time.time())
+if aad_token['expires_on'] > (now + TOKEN_REFRESH_LEAD_TIME):
+return True
+return False
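
In isolation the rule is simply that the token must outlive now plus TOKEN_REFRESH_LEAD_TIME (120 seconds, per the "expires in more than 2 minutes" comment in the old code above). A runnable sketch mirroring the unit tests added in this commit:

import time

TOKEN_REFRESH_LEAD_TIME = 120  # same value the module defines

def is_aad_token_valid(aad_token: dict) -> bool:
    # Valid only with more than the lead time left, so a token is never
    # handed out when it could expire mid-request.
    return aad_token['expires_on'] > int(time.time()) + TOKEN_REFRESH_LEAD_TIME

assert is_aad_token_valid({'token': 't', 'expires_on': int(time.time()) + TOKEN_REFRESH_LEAD_TIME + 10})
assert not is_aad_token_valid({'token': 't', 'expires_on': int(time.time())})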

+@staticmethod
+def _check_azure_metadata_service() -> None:
+"""
+Check for Azure Metadata Service
+https://docs.microsoft.com/en-us/azure/virtual-machines/linux/instance-metadata-service
+"""
+try:
+jsn = requests.get(
+AZURE_METADATA_SERVICE_TOKEN_URL,
+params={"api-version": "2021-02-01"},
+headers={"Metadata": "true"},
+timeout=2,
+).json()
+if 'compute' not in jsn or 'azEnvironment' not in jsn['compute']:
+raise AirflowException(
+f"Was able to fetch some metadata, but it doesn't look like Azure Metadata: {jsn}"
+)
+except (requests_exceptions.RequestException, ValueError) as e:
+raise AirflowException(f"Can't reach Azure Metadata Service: {e}")
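
The probe only asserts that the response is shaped like Instance Metadata Service output, i.e. that compute.azEnvironment exists; it does not validate the value. A sketch of that predicate with illustrative payloads (the new test below stubs the happy path as {'compute': {'azEnvironment': 'AZUREPUBLICCLOUD'}}):

def looks_like_azure_metadata(jsn: dict) -> bool:
    # Only the presence of compute.azEnvironment matters to the hook.
    return 'azEnvironment' in jsn.get('compute', {})

assert looks_like_azure_metadata({'compute': {'azEnvironment': 'AZUREPUBLICCLOUD'}})
assert not looks_like_azure_metadata({'error': 'not IMDS'})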

def _do_api_call(self, endpoint_info, json):
"""
@@ -265,14 +303,10 @@ def _do_api_call(self, endpoint_info, json):
:rtype: dict
"""
method, endpoint = endpoint_info
+url = f'https://{self.host}/{endpoint}'

-self.databricks_conn = self.get_connection(self.databricks_conn_id)
-
-headers = USER_AGENT_HEADER.copy()
-if 'host' in self.databricks_conn.extra_dejson:
-host = self._parse_host(self.databricks_conn.extra_dejson['host'])
-else:
-host = self.databricks_conn.host
+aad_headers = self._get_aad_headers()
+headers = {**USER_AGENT_HEADER.copy(), **aad_headers}

if 'token' in self.databricks_conn.extra_dejson:
self.log.info(
@@ -285,34 +319,16 @@ def _do_api_call(self, endpoint_info, json):
elif 'azure_tenant_id' in self.databricks_conn.extra_dejson:
if self.databricks_conn.login == "" or self.databricks_conn.password == "":
raise AirflowException("Azure SPN credentials aren't provided")

-self.log.info('Using AAD Token for SPN. ')
-auth = _TokenAuth(self._fill_aad_tokens(headers))
+self.log.info('Using AAD Token for SPN.')
+auth = _TokenAuth(self._get_aad_token(DEFAULT_DATABRICKS_SCOPE))
elif self.databricks_conn.extra_dejson.get('use_azure_managed_identity', False):
self.log.info('Using AAD Token for managed identity.')
-# check for Azure Metadata Service
-# https://docs.microsoft.com/en-us/azure/virtual-machines/linux/instance-metadata-service
-try:
-jsn = requests.get(
-AZURE_METADATA_SERVICE_TOKEN_URL,
-params={"api-version": "2021-02-01"},
-headers={"Metadata": "true"},
-timeout=2,
-).json()
-if 'compute' not in jsn or 'azEnvironment' not in jsn['compute']:
-raise AirflowException(
-f"Was able to fetch some metadata, but it doesn't look like Azure Metadata: {jsn}"
-)
-except (requests_exceptions.RequestException, ValueError) as e:
-raise AirflowException(f"Can't reach Azure Metadata Service: {e}")
-
-auth = _TokenAuth(self._fill_aad_tokens(headers))
+self._check_azure_metadata_service()
+auth = _TokenAuth(self._get_aad_token(DEFAULT_DATABRICKS_SCOPE))
else:
self.log.info('Using basic auth.')
auth = (self.databricks_conn.login, self.databricks_conn.password)
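
# A hedged sketch of the connection extras that select each branch above;
# key names are taken from this function, every value is a placeholder:
#   {'token': 'dapi0123456789abcdef'}     -> personal access token via _TokenAuth
#   {'azure_tenant_id': 'my-tenant-id'}   -> AAD SPN token; login/password carry the client id/secret
#   {'use_azure_managed_identity': True}  -> managed identity, gated by the metadata-service check
#   {'azure_resource_id': '/subscriptions/00000000/.../myworkspace'}  -> extra X-Databricks-Azure-* headers via _get_aad_headers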

-url = f'https://{self._parse_host(host)}/{endpoint}'
-
if method == 'GET':
request_func = requests.get
elif method == 'POST':
@@ -356,31 +372,31 @@ def _do_api_call(self, endpoint_info, json):
def _log_request_error(self, attempt_num: int, error: str) -> None:
self.log.error('Attempt %s API Request to Databricks failed with reason: %s', attempt_num, error)

-def run_now(self, json: dict) -> str:
+def run_now(self, json: dict) -> int:
"""
Utility function to call the ``api/2.0/jobs/run-now`` endpoint.
:param json: The data used in the body of the request to the ``run-now`` endpoint.
:type json: dict
-:return: the run_id as a string
+:return: the run_id as an int
:rtype: int
"""
response = self._do_api_call(RUN_NOW_ENDPOINT, json)
return response['run_id']

-def submit_run(self, json: dict) -> str:
+def submit_run(self, json: dict) -> int:
"""
Utility function to call the ``api/2.0/jobs/runs/submit`` endpoint.
:param json: The data used in the body of the request to the ``submit`` endpoint.
:type json: dict
-:return: the run_id as a string
+:return: the run_id as an int
:rtype: int
"""
response = self._do_api_call(SUBMIT_RUN_ENDPOINT, json)
return response['run_id']
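
A hedged end-to-end sketch of these helpers (connection id and payload values are illustrative; the payload shape follows the test fixtures below):

hook = DatabricksHook(databricks_conn_id='databricks_default')  # assumed default connection id
run_id = hook.submit_run({
    'new_cluster': {'spark_version': '9.1.x-scala2.12', 'num_workers': 2},  # placeholder cluster spec
    'notebook_task': {'notebook_path': '/Users/someone@example.com/PrepareData'},  # placeholder notebook
})
print(hook.get_run_page_url(run_id), hook.get_run_state(run_id).life_cycle_state)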

-def get_run_page_url(self, run_id: str) -> str:
+def get_run_page_url(self, run_id: int) -> str:
"""
Retrieves run_page_url.
@@ -391,19 +407,19 @@ def get_run_page_url(self, run_id: str) -> str:
response = self._do_api_call(GET_RUN_ENDPOINT, json)
return response['run_page_url']

-def get_job_id(self, run_id: str) -> str:
+def get_job_id(self, run_id: int) -> int:
"""
Retrieves job_id from run_id.
:param run_id: id of the run
-:type run_id: str
+:type run_id: int
:return: Job id for given Databricks run
"""
json = {'run_id': run_id}
response = self._do_api_call(GET_RUN_ENDPOINT, json)
return response['job_id']

-def get_run_state(self, run_id: str) -> RunState:
+def get_run_state(self, run_id: int) -> RunState:
"""
Retrieves run state of the run.
@@ -421,13 +437,9 @@ def get_run_state(self, run_id: str) -> RunState:
json = {'run_id': run_id}
response = self._do_api_call(GET_RUN_ENDPOINT, json)
state = response['state']
-life_cycle_state = state['life_cycle_state']
-# result_state may not be in the state if not terminal
-result_state = state.get('result_state', None)
-state_message = state['state_message']
-return RunState(life_cycle_state, result_state, state_message)
+return RunState(**state)

-def get_run_state_str(self, run_id: str) -> str:
+def get_run_state_str(self, run_id: int) -> str:
"""
Return the string representation of RunState.
@@ -440,7 +452,7 @@ def get_run_state_str(self, run_id: str) -> str:
)
return run_state_str

-def get_run_state_lifecycle(self, run_id: str) -> str:
+def get_run_state_lifecycle(self, run_id: int) -> str:
"""
Returns the lifecycle state of the run
@@ -449,7 +461,7 @@
"""
return self.get_run_state(run_id).life_cycle_state

-def get_run_state_result(self, run_id: str) -> str:
+def get_run_state_result(self, run_id: int) -> str:
"""
Returns the resulting state of the run
@@ -458,7 +470,7 @@
"""
return self.get_run_state(run_id).result_state

-def get_run_state_message(self, run_id: str) -> str:
+def get_run_state_message(self, run_id: int) -> str:
"""
Returns the state message for the run
@@ -467,7 +479,7 @@
"""
return self.get_run_state(run_id).state_message

-def cancel_run(self, run_id: str) -> None:
+def cancel_run(self, run_id: int) -> None:
"""
Cancels the run.
@@ -531,9 +543,6 @@ def _retryable_error(exception) -> bool:
)


-RUN_LIFE_CYCLE_STATES = ['PENDING', 'RUNNING', 'TERMINATING', 'TERMINATED', 'SKIPPED', 'INTERNAL_ERROR']
-
-
class _TokenAuth(AuthBase):
"""
Helper class for requests Auth field. AuthBase requires you to implement the __call__
56 changes: 55 additions & 1 deletion tests/providers/databricks/hooks/test_databricks.py
@@ -19,6 +19,7 @@

import itertools
import json
+import time
import unittest
from unittest import mock

@@ -31,9 +32,11 @@
from airflow.providers.databricks.hooks.databricks import (
AZURE_DEFAULT_AD_ENDPOINT,
AZURE_MANAGEMENT_ENDPOINT,
+AZURE_METADATA_SERVICE_TOKEN_URL,
AZURE_TOKEN_SERVICE_URL,
DEFAULT_DATABRICKS_SCOPE,
SUBMIT_RUN_ENDPOINT,
+TOKEN_REFRESH_LEAD_TIME,
DatabricksHook,
RunState,
)
@@ -63,7 +66,7 @@
}
NOTEBOOK_PARAMS = {"dry-run": "true", "oldest-time-to-consider": "1457570074236"}
JAR_PARAMS = ["param1", "param2"]
-RESULT_STATE = None # type: None
+RESULT_STATE = ''
LIBRARIES = [
{"jar": "dbfs:/mnt/libraries/library.jar"},
{"maven": {"coordinates": "org.jsoup:jsoup:1.7.2", "exclusions": ["slf4j:slf4j"]}},
@@ -520,6 +523,14 @@ def test_uninstall_libs_on_cluster(self, mock_requests):
timeout=self.hook.timeout_seconds,
)

+def test_is_aad_token_valid_returns_true(self):
+aad_token = {'token': 'my_token', 'expires_on': int(time.time()) + TOKEN_REFRESH_LEAD_TIME + 10}
+self.assertTrue(self.hook._is_aad_token_valid(aad_token))
+
+def test_is_aad_token_valid_returns_false(self):
+aad_token = {'token': 'my_token', 'expires_on': int(time.time())}
+self.assertFalse(self.hook._is_aad_token_valid(aad_token))


class TestDatabricksHookToken(unittest.TestCase):
"""
@@ -762,3 +773,46 @@ def test_submit_run(self, mock_requests):
assert kwargs['auth'].token == TOKEN
assert kwargs['headers']['X-Databricks-Azure-Workspace-Resource-Id'] == '/Some/resource'
assert kwargs['headers']['X-Databricks-Azure-SP-Management-Token'] == TOKEN


+class TestDatabricksHookAadTokenManagedIdentity(unittest.TestCase):
+"""
+Tests for DatabricksHook when auth is done with AAD leveraging Managed Identity authentication
+"""
+
+@provide_session
+def setUp(self, session=None):
+conn = session.query(Connection).filter(Connection.conn_id == DEFAULT_CONN_ID).first()
+conn.host = HOST
+conn.extra = json.dumps(
+{
+'use_azure_managed_identity': True,
+}
+)
+session.commit()
+self.hook = DatabricksHook()
+
+@mock.patch('airflow.providers.databricks.hooks.databricks.requests')
+def test_submit_run(self, mock_requests):
+mock_requests.codes.ok = 200
+mock_requests.get.side_effect = [
+create_successful_response_mock({'compute': {'azEnvironment': 'AZUREPUBLICCLOUD'}}),
+create_successful_response_mock(create_aad_token_for_resource(DEFAULT_DATABRICKS_SCOPE)),
+]
+mock_requests.post.side_effect = [
+create_successful_response_mock({'run_id': '1'}),
+]
+status_code_mock = mock.PropertyMock(return_value=200)
+type(mock_requests.post.return_value).status_code = status_code_mock
+data = {'notebook_task': NOTEBOOK_TASK, 'new_cluster': NEW_CLUSTER}
+run_id = self.hook.submit_run(data)
+
+ad_call_args = mock_requests.method_calls[0]
+assert ad_call_args[1][0] == AZURE_METADATA_SERVICE_TOKEN_URL
+assert ad_call_args[2]['params']['api-version'] > '2018-02-01'
+assert ad_call_args[2]['headers']['Metadata'] == 'true'
+
+assert run_id == '1'
+args = mock_requests.post.call_args
+kwargs = args[1]
+assert kwargs['auth'].token == TOKEN
