diff --git a/airflow/providers/databricks/hooks/databricks.py b/airflow/providers/databricks/hooks/databricks.py
index 9e4ecaff372b8..ac8d9511e0676 100644
--- a/airflow/providers/databricks/hooks/databricks.py
+++ b/airflow/providers/databricks/hooks/databricks.py
@@ -48,6 +48,8 @@
 
 USER_AGENT_HEADER = {'user-agent': f'airflow-{__version__}'}
 
+RUN_LIFE_CYCLE_STATES = ['PENDING', 'RUNNING', 'TERMINATING', 'TERMINATED', 'SKIPPED', 'INTERNAL_ERROR']
+
 # https://docs.microsoft.com/en-us/azure/databricks/dev-tools/api/latest/aad/service-prin-aad-token#--get-an-azure-active-directory-access-token
 # https://docs.microsoft.com/en-us/graph/deployments#app-registration-and-token-service-root-endpoints
 AZURE_DEFAULT_AD_ENDPOINT = "https://login.microsoftonline.com"
@@ -64,7 +66,9 @@
 class RunState:
     """Utility class for the run state concept of Databricks runs."""
 
-    def __init__(self, life_cycle_state: str, result_state: str, state_message: str) -> None:
+    def __init__(
+        self, life_cycle_state: str, result_state: str = '', state_message: str = '', *args, **kwargs
+    ) -> None:
         self.life_cycle_state = life_cycle_state
         self.result_state = result_state
         self.state_message = state_message
@@ -131,7 +135,11 @@ def __init__(
     ) -> None:
         super().__init__()
         self.databricks_conn_id = databricks_conn_id
-        self.databricks_conn = None
+        self.databricks_conn = self.get_connection(databricks_conn_id)
+        if 'host' in self.databricks_conn.extra_dejson:
+            self.host = self._parse_host(self.databricks_conn.extra_dejson['host'])
+        else:
+            self.host = self._parse_host(self.databricks_conn.host)
         self.timeout_seconds = timeout_seconds
         if retry_limit < 1:
             raise ValueError('Retry limit must be greater than equal to 1')
@@ -173,13 +181,11 @@ def _get_aad_token(self, resource: str) -> str:
         :param resource: resource to issue token to
         :return: AAD token, or raise an exception
         """
-        if resource in self.aad_tokens:
-            d = self.aad_tokens[resource]
-            now = int(time.time())
-            if d['expires_on'] > (now + TOKEN_REFRESH_LEAD_TIME):  # it expires in more than 2 minutes
-                return d['token']
-            self.log.info("Existing AAD token is expired, or going to expire soon. Refreshing...")
+        aad_token = self.aad_tokens.get(resource)
+        if aad_token and self._is_aad_token_valid(aad_token):
+            return aad_token['token']
+        self.log.info('Existing AAD token is expired, or going to expire soon. Refreshing...')
 
         attempt_num = 1
         while True:
             try:
@@ -235,21 +241,53 @@ def _get_aad_token(self, resource: str) -> str:
                 attempt_num += 1
                 sleep(self.retry_delay)
 
-    def _fill_aad_tokens(self, headers: dict) -> str:
+    def _get_aad_headers(self) -> dict:
         """
-        Fills headers if necessary (SPN is outside of the workspace) and generates AAD token
-        :param headers: dictionary with headers to fill-in
-        :return: AAD token
+        Fills AAD headers if necessary (SPN is outside of the workspace)
+        :return: dictionary with filled AAD headers
         """
-        # SP is outside of the workspace
+        headers = {}
         if 'azure_resource_id' in self.databricks_conn.extra_dejson:
             mgmt_token = self._get_aad_token(AZURE_MANAGEMENT_ENDPOINT)
             headers['X-Databricks-Azure-Workspace-Resource-Id'] = self.databricks_conn.extra_dejson[
                 'azure_resource_id'
             ]
             headers['X-Databricks-Azure-SP-Management-Token'] = mgmt_token
+        return headers
 
-        return self._get_aad_token(DEFAULT_DATABRICKS_SCOPE)
+    @staticmethod
+    def _is_aad_token_valid(aad_token: dict) -> bool:
+        """
+        Utility function to check AAD token hasn't expired yet
+        :param aad_token: dict with properties of AAD token
+        :type aad_token: dict
+        :return: true if token is valid, false otherwise
+        :rtype: bool
+        """
+        now = int(time.time())
+        if aad_token['expires_on'] > (now + TOKEN_REFRESH_LEAD_TIME):
+            return True
+        return False
+
+    @staticmethod
+    def _check_azure_metadata_service() -> None:
+        """
+        Check for Azure Metadata Service
+        https://docs.microsoft.com/en-us/azure/virtual-machines/linux/instance-metadata-service
+        """
+        try:
+            jsn = requests.get(
+                AZURE_METADATA_SERVICE_TOKEN_URL,
+                params={"api-version": "2021-02-01"},
+                headers={"Metadata": "true"},
+                timeout=2,
+            ).json()
+            if 'compute' not in jsn or 'azEnvironment' not in jsn['compute']:
+                raise AirflowException(
+                    f"Was able to fetch some metadata, but it doesn't look like Azure Metadata: {jsn}"
+                )
+        except (requests_exceptions.RequestException, ValueError) as e:
+            raise AirflowException(f"Can't reach Azure Metadata Service: {e}")
 
     def _do_api_call(self, endpoint_info, json):
         """
@@ -265,14 +303,10 @@ def _do_api_call(self, endpoint_info, json):
         :rtype: dict
         """
         method, endpoint = endpoint_info
+        url = f'https://{self.host}/{endpoint}'
 
-        self.databricks_conn = self.get_connection(self.databricks_conn_id)
-
-        headers = USER_AGENT_HEADER.copy()
-        if 'host' in self.databricks_conn.extra_dejson:
-            host = self._parse_host(self.databricks_conn.extra_dejson['host'])
-        else:
-            host = self.databricks_conn.host
+        aad_headers = self._get_aad_headers()
+        headers = {**USER_AGENT_HEADER.copy(), **aad_headers}
 
         if 'token' in self.databricks_conn.extra_dejson:
             self.log.info(
@@ -285,34 +319,16 @@ def _do_api_call(self, endpoint_info, json):
         elif 'azure_tenant_id' in self.databricks_conn.extra_dejson:
             if self.databricks_conn.login == "" or self.databricks_conn.password == "":
                 raise AirflowException("Azure SPN credentials aren't provided")
-
-            self.log.info('Using AAD Token for SPN. ')
-            auth = _TokenAuth(self._fill_aad_tokens(headers))
+            self.log.info('Using AAD Token for SPN.')
+            auth = _TokenAuth(self._get_aad_token(DEFAULT_DATABRICKS_SCOPE))
         elif self.databricks_conn.extra_dejson.get('use_azure_managed_identity', False):
             self.log.info('Using AAD Token for managed identity.')
-            # check for Azure Metadata Service
-            # https://docs.microsoft.com/en-us/azure/virtual-machines/linux/instance-metadata-service
-            try:
-                jsn = requests.get(
-                    AZURE_METADATA_SERVICE_TOKEN_URL,
-                    params={"api-version": "2021-02-01"},
-                    headers={"Metadata": "true"},
-                    timeout=2,
-                ).json()
-                if 'compute' not in jsn or 'azEnvironment' not in jsn['compute']:
-                    raise AirflowException(
-                        f"Was able to fetch some metadata, but it doesn't look like Azure Metadata: {jsn}"
-                    )
-            except (requests_exceptions.RequestException, ValueError) as e:
-                raise AirflowException(f"Can't reach Azure Metadata Service: {e}")
-
-            auth = _TokenAuth(self._fill_aad_tokens(headers))
+            self._check_azure_metadata_service()
+            auth = _TokenAuth(self._get_aad_token(DEFAULT_DATABRICKS_SCOPE))
         else:
             self.log.info('Using basic auth.')
             auth = (self.databricks_conn.login, self.databricks_conn.password)
 
-        url = f'https://{self._parse_host(host)}/{endpoint}'
-
         if method == 'GET':
             request_func = requests.get
         elif method == 'POST':
@@ -356,31 +372,31 @@ def _do_api_call(self, endpoint_info, json):
     def _log_request_error(self, attempt_num: int, error: str) -> None:
         self.log.error('Attempt %s API Request to Databricks failed with reason: %s', attempt_num, error)
 
-    def run_now(self, json: dict) -> str:
+    def run_now(self, json: dict) -> int:
         """
         Utility function to call the ``api/2.0/jobs/run-now`` endpoint.
 
         :param json: The data used in the body of the request to the ``run-now`` endpoint.
         :type json: dict
-        :return: the run_id as a string
-        :rtype: str
+        :return: the run_id as an int
+        :rtype: int
         """
         response = self._do_api_call(RUN_NOW_ENDPOINT, json)
         return response['run_id']
 
-    def submit_run(self, json: dict) -> str:
+    def submit_run(self, json: dict) -> int:
         """
         Utility function to call the ``api/2.0/jobs/runs/submit`` endpoint.
 
         :param json: The data used in the body of the request to the ``submit`` endpoint.
         :type json: dict
-        :return: the run_id as a string
-        :rtype: str
+        :return: the run_id as an int
+        :rtype: int
         """
         response = self._do_api_call(SUBMIT_RUN_ENDPOINT, json)
         return response['run_id']
 
-    def get_run_page_url(self, run_id: str) -> str:
+    def get_run_page_url(self, run_id: int) -> str:
         """
         Retrieves run_page_url.
 
@@ -391,19 +407,19 @@ def get_run_page_url(self, run_id: str) -> str:
         response = self._do_api_call(GET_RUN_ENDPOINT, json)
         return response['run_page_url']
 
-    def get_job_id(self, run_id: str) -> str:
+    def get_job_id(self, run_id: int) -> int:
         """
         Retrieves job_id from run_id.
 
         :param run_id: id of the run
-        :type run_id: str
+        :type run_id: int
         :return: Job id for given Databricks run
         """
         json = {'run_id': run_id}
         response = self._do_api_call(GET_RUN_ENDPOINT, json)
         return response['job_id']
 
-    def get_run_state(self, run_id: str) -> RunState:
+    def get_run_state(self, run_id: int) -> RunState:
         """
         Retrieves run state of the run.
 
@@ -421,13 +437,9 @@ def get_run_state(self, run_id: str) -> RunState:
         json = {'run_id': run_id}
         response = self._do_api_call(GET_RUN_ENDPOINT, json)
         state = response['state']
-        life_cycle_state = state['life_cycle_state']
-        # result_state may not be in the state if not terminal
-        result_state = state.get('result_state', None)
-        state_message = state['state_message']
-        return RunState(life_cycle_state, result_state, state_message)
+        return RunState(**state)
 
-    def get_run_state_str(self, run_id: str) -> str:
+    def get_run_state_str(self, run_id: int) -> str:
         """
         Return the string representation of RunState.
 
@@ -440,7 +452,7 @@
         )
         return run_state_str
 
-    def get_run_state_lifecycle(self, run_id: str) -> str:
+    def get_run_state_lifecycle(self, run_id: int) -> str:
         """
         Returns the lifecycle state of the run
 
@@ -449,7 +461,7 @@
         """
         return self.get_run_state(run_id).life_cycle_state
 
-    def get_run_state_result(self, run_id: str) -> str:
+    def get_run_state_result(self, run_id: int) -> str:
         """
         Returns the resulting state of the run
 
@@ -458,7 +470,7 @@
         """
         return self.get_run_state(run_id).result_state
 
-    def get_run_state_message(self, run_id: str) -> str:
+    def get_run_state_message(self, run_id: int) -> str:
         """
         Returns the state message for the run
 
@@ -467,7 +479,7 @@
         """
         return self.get_run_state(run_id).state_message
 
-    def cancel_run(self, run_id: str) -> None:
+    def cancel_run(self, run_id: int) -> None:
         """
         Cancels the run.
 
@@ -531,9 +543,6 @@ def _retryable_error(exception) -> bool:
     )
 
 
-RUN_LIFE_CYCLE_STATES = ['PENDING', 'RUNNING', 'TERMINATING', 'TERMINATED', 'SKIPPED', 'INTERNAL_ERROR']
-
-
 class _TokenAuth(AuthBase):
     """
     Helper class for requests Auth field. AuthBase requires you to implement the __call__
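
Note on the hook changes above: AAD tokens are now cached per resource in self.aad_tokens and refreshed only when they are about to expire. A minimal standalone sketch of that cache-and-validate pattern follows; fetch_new_token is a hypothetical stand-in for the retrying AAD request inside _get_aad_token:

import time

TOKEN_REFRESH_LEAD_TIME = 120  # two-minute refresh window, as in the hook

def fetch_new_token(resource: str) -> dict:
    # Hypothetical stand-in for the retrying AAD request in _get_aad_token;
    # the real hook calls the AAD token endpoint and retries on failure.
    return {'token': f'token-for-{resource}', 'expires_on': int(time.time()) + 3600}

def is_aad_token_valid(aad_token: dict) -> bool:
    # Mirrors _is_aad_token_valid: the cached token must outlive the lead time.
    return aad_token['expires_on'] > int(time.time()) + TOKEN_REFRESH_LEAD_TIME

aad_tokens: dict = {}  # per-resource cache, like self.aad_tokens

def get_aad_token(resource: str) -> str:
    cached = aad_tokens.get(resource)
    if cached and is_aad_token_valid(cached):
        return cached['token']
    # Expired or missing: fetch and cache a fresh token.
    aad_tokens[resource] = fetch_new_token(resource)
    return aad_tokens[resource]['token']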
diff --git a/tests/providers/databricks/hooks/test_databricks.py b/tests/providers/databricks/hooks/test_databricks.py
index a5f0fb467fe64..ea688e87dc955 100644
--- a/tests/providers/databricks/hooks/test_databricks.py
+++ b/tests/providers/databricks/hooks/test_databricks.py
@@ -19,6 +19,7 @@
 
 import itertools
 import json
+import time
 import unittest
 from unittest import mock
 
@@ -31,9 +32,11 @@
 from airflow.providers.databricks.hooks.databricks import (
     AZURE_DEFAULT_AD_ENDPOINT,
     AZURE_MANAGEMENT_ENDPOINT,
+    AZURE_METADATA_SERVICE_TOKEN_URL,
     AZURE_TOKEN_SERVICE_URL,
     DEFAULT_DATABRICKS_SCOPE,
     SUBMIT_RUN_ENDPOINT,
+    TOKEN_REFRESH_LEAD_TIME,
     DatabricksHook,
     RunState,
 )
@@ -63,7 +66,7 @@
 }
 NOTEBOOK_PARAMS = {"dry-run": "true", "oldest-time-to-consider": "1457570074236"}
 JAR_PARAMS = ["param1", "param2"]
-RESULT_STATE = None  # type: None
+RESULT_STATE = ''
 LIBRARIES = [
     {"jar": "dbfs:/mnt/libraries/library.jar"},
     {"maven": {"coordinates": "org.jsoup:jsoup:1.7.2", "exclusions": ["slf4j:slf4j"]}},
@@ -520,6 +523,14 @@ def test_uninstall_libs_on_cluster(self, mock_requests):
             timeout=self.hook.timeout_seconds,
         )
 
+    def test_is_aad_token_valid_returns_true(self):
+        aad_token = {'token': 'my_token', 'expires_on': int(time.time()) + TOKEN_REFRESH_LEAD_TIME + 10}
+        self.assertTrue(self.hook._is_aad_token_valid(aad_token))
+
+    def test_is_aad_token_valid_returns_false(self):
+        aad_token = {'token': 'my_token', 'expires_on': int(time.time())}
+        self.assertFalse(self.hook._is_aad_token_valid(aad_token))
+
 
 class TestDatabricksHookToken(unittest.TestCase):
     """
@@ -762,3 +773,46 @@ def test_submit_run(self, mock_requests):
         assert kwargs['auth'].token == TOKEN
         assert kwargs['headers']['X-Databricks-Azure-Workspace-Resource-Id'] == '/Some/resource'
         assert kwargs['headers']['X-Databricks-Azure-SP-Management-Token'] == TOKEN
+
+
+class TestDatabricksHookAadTokenManagedIdentity(unittest.TestCase):
+    """
+    Tests for DatabricksHook when auth is done with AAD leveraging Managed Identity authentication
+    """
+
+    @provide_session
+    def setUp(self, session=None):
+        conn = session.query(Connection).filter(Connection.conn_id == DEFAULT_CONN_ID).first()
+        conn.host = HOST
+        conn.extra = json.dumps(
+            {
+                'use_azure_managed_identity': True,
+            }
+        )
+        session.commit()
+        self.hook = DatabricksHook()
+
+    @mock.patch('airflow.providers.databricks.hooks.databricks.requests')
+    def test_submit_run(self, mock_requests):
+        mock_requests.codes.ok = 200
+        mock_requests.get.side_effect = [
+            create_successful_response_mock({'compute': {'azEnvironment': 'AZUREPUBLICCLOUD'}}),
+            create_successful_response_mock(create_aad_token_for_resource(DEFAULT_DATABRICKS_SCOPE)),
+        ]
+        mock_requests.post.side_effect = [
+            create_successful_response_mock({'run_id': '1'}),
+        ]
+        status_code_mock = mock.PropertyMock(return_value=200)
+        type(mock_requests.post.return_value).status_code = status_code_mock
+        data = {'notebook_task': NOTEBOOK_TASK, 'new_cluster': NEW_CLUSTER}
+        run_id = self.hook.submit_run(data)
+
+        ad_call_args = mock_requests.method_calls[0]
+        assert ad_call_args[1][0] == AZURE_METADATA_SERVICE_TOKEN_URL
+        assert ad_call_args[2]['params']['api-version'] > '2018-02-01'
+        assert ad_call_args[2]['headers']['Metadata'] == 'true'
+
+        assert run_id == '1'
+        args = mock_requests.post.call_args
+        kwargs = args[1]
+        assert kwargs['auth'].token == TOKEN
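
Two behavioural notes on the changes above, each with a small illustrative sketch. First, because result_state and state_message now default to '' and unknown keys are absorbed by *args/**kwargs, get_run_state can unpack the raw API state dict directly via RunState(**state); the sample response dict below is illustrative, not captured from a real run:

class RunState:
    # Same constructor shape as the hook's RunState after this change.
    def __init__(
        self, life_cycle_state: str, result_state: str = '', state_message: str = '', *args, **kwargs
    ) -> None:
        self.life_cycle_state = life_cycle_state
        self.result_state = result_state
        self.state_message = state_message

# A non-terminal run carries no result_state, and the API may include fields the
# constructor does not model; both are now tolerated by RunState(**state).
state = {'life_cycle_state': 'RUNNING', 'state_message': 'In run', 'extra_api_field': True}
run_state = RunState(**state)
assert run_state.result_state == ''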
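
Second, connection lookup and host parsing moved from _do_api_call into __init__, so the connection is read once per hook instance and request URLs are built from self.host. A usage sketch, assuming an Airflow deployment with a configured 'databricks_default' connection; the cluster spec and notebook path are placeholders, not values from this change:

from airflow.providers.databricks.hooks.databricks import DatabricksHook

# The connection (and self.host) is resolved once here, not on every API call.
hook = DatabricksHook(databricks_conn_id='databricks_default')
run_id = hook.submit_run(
    {
        'new_cluster': {'spark_version': '9.1.x-scala2.12', 'node_type_id': 'r3.xlarge', 'num_workers': 2},
        'notebook_task': {'notebook_path': '/Users/someone@example.com/PrepareData'},
    }
)
state = hook.get_run_state(run_id)  # a RunState built via RunState(**state)
print(state.life_cycle_state, state.result_state)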