Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Databricks - allow Azure SP authentication on other Azure clouds #19722

Merged
merged 2 commits into from
Nov 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions airflow/providers/databricks/hooks/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@
USER_AGENT_HEADER = {'user-agent': f'airflow-{__version__}'}

# https://docs.microsoft.com/en-us/azure/databricks/dev-tools/api/latest/aad/service-prin-aad-token#--get-an-azure-active-directory-access-token
AZURE_TOKEN_SERVICE_URL = "https://login.microsoftonline.com/{}/oauth2/token"
# https://docs.microsoft.com/en-us/graph/deployments#app-registration-and-token-service-root-endpoints
AZURE_DEFAULT_AD_ENDPOINT = "https://login.microsoftonline.com"
AZURE_TOKEN_SERVICE_URL = "{}/{}/oauth2/token"
# https://docs.microsoft.com/en-us/azure/active-directory/managed-identities-azure-resources/how-to-use-vm-token
AZURE_METADATA_SERVICE_TOKEN_URL = "http://169.254.169.254/metadata/identity/oauth2/token"
AZURE_METADATA_SERVICE_INSTANCE_URL = "http://169.254.169.254/metadata/instance"
Expand Down Expand Up @@ -200,8 +202,11 @@ def _get_aad_token(self, resource: str) -> str:
"resource": resource,
"client_secret": self.databricks_conn.password,
}
azure_ad_endpoint = self.databricks_conn.extra_dejson.get(
"azure_ad_endpoint", AZURE_DEFAULT_AD_ENDPOINT
)
resp = requests.post(
AZURE_TOKEN_SERVICE_URL.format(tenant_id),
AZURE_TOKEN_SERVICE_URL.format(azure_ad_endpoint, tenant_id),
data=data,
headers={**USER_AGENT_HEADER, 'Content-Type': 'application/x-www-form-urlencoded'},
timeout=self.aad_timeout_seconds,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ Extra (optional)
* ``azure_tenant_id``: ID of the Azure Active Directory tenant
* ``azure_resource_id``: optional Resource ID of the Azure Databricks workspace (required if the Service Principal isn't
a user inside the workspace)
* ``azure_ad_endpoint``: optional URL of the Azure AD endpoint to use if you're working with a special `Azure Cloud (GovCloud, China, Germany) <https://docs.microsoft.com/en-us/graph/deployments#app-registration-and-token-service-root-endpoints>`_. The value must include the protocol, for example: ``https://login.microsoftonline.de``.

The following parameters are necessary when authenticating with an AAD token for an Azure managed identity:

Expand Down
63 changes: 62 additions & 1 deletion tests/providers/databricks/hooks/test_databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@
from airflow.exceptions import AirflowException
from airflow.models import Connection
from airflow.providers.databricks.hooks.databricks import (
AZURE_DEFAULT_AD_ENDPOINT,
AZURE_MANAGEMENT_ENDPOINT,
AZURE_TOKEN_SERVICE_URL,
DEFAULT_DATABRICKS_SCOPE,
SUBMIT_RUN_ENDPOINT,
DatabricksHook,
Expand Down Expand Up @@ -638,6 +640,53 @@ def test_submit_run(self, mock_requests):
assert kwargs['auth'].token == TOKEN


class TestDatabricksHookAadTokenOtherClouds(unittest.TestCase):
    """
    Exercise DatabricksHook with AAD-token authentication for a service principal that is a
    workspace user, where the connection targets a non-global Azure cloud
    (China, GovCloud, Germany) through the ``azure_ad_endpoint`` extra.
    """

    @provide_session
    def setUp(self, session=None):
        # Values the test later expects to find in the recorded token request.
        self.client_id = '9ff815a6-4404-4ab8-85cb-cd0e6f879c1d'
        self.tenant_id = '3ff810a6-5504-4ab8-85cb-cd0e6f879c1d'
        self.ad_endpoint = 'https://login.microsoftonline.de'

        extra = {
            'host': HOST,
            'azure_tenant_id': self.tenant_id,
            'azure_ad_endpoint': self.ad_endpoint,
        }
        conn = session.query(Connection).filter(Connection.conn_id == DEFAULT_CONN_ID).first()
        conn.login = self.client_id
        conn.password = 'secret'
        conn.extra = json.dumps(extra)
        session.commit()

        self.hook = DatabricksHook()

    @mock.patch('airflow.providers.databricks.hooks.databricks.requests')
    def test_submit_run(self, mock_requests):
        mock_requests.codes.ok = 200
        # Responses are consumed in order: the AAD token reply, then the run-submit reply.
        token_response = create_successful_response_mock(
            create_aad_token_for_resource(DEFAULT_DATABRICKS_SCOPE)
        )
        submit_response = create_successful_response_mock({'run_id': '1'})
        mock_requests.post.side_effect = [token_response, submit_response]
        type(mock_requests.post.return_value).status_code = mock.PropertyMock(return_value=200)

        payload = {'notebook_task': NOTEBOOK_TASK, 'new_cluster': NEW_CLUSTER}
        run_id = self.hook.submit_run(payload)

        # The first recorded call must target the custom AD endpoint with the SP credentials.
        _, token_args, token_kwargs = mock_requests.method_calls[0]
        expected_token_url = AZURE_TOKEN_SERVICE_URL.format(self.ad_endpoint, self.tenant_id)
        assert token_args[0] == expected_token_url
        assert token_kwargs['data']['client_id'] == self.client_id
        assert token_kwargs['data']['resource'] == DEFAULT_DATABRICKS_SCOPE

        assert run_id == '1'
        # The most recent POST must carry the obtained token as its auth.
        _, last_post_kwargs = mock_requests.post.call_args
        assert last_post_kwargs['auth'].token == TOKEN


class TestDatabricksHookAadTokenSpOutside(unittest.TestCase):
"""
Tests for DatabricksHook when auth is done with AAD token for SP outside of workspace.
Expand All @@ -646,7 +695,9 @@ class TestDatabricksHookAadTokenSpOutside(unittest.TestCase):
@provide_session
def setUp(self, session=None):
conn = session.query(Connection).filter(Connection.conn_id == DEFAULT_CONN_ID).first()
conn.login = '9ff815a6-4404-4ab8-85cb-cd0e6f879c1d'
self.tenant_id = '3ff810a6-5504-4ab8-85cb-cd0e6f879c1d'
self.client_id = '9ff815a6-4404-4ab8-85cb-cd0e6f879c1d'
conn.login = self.client_id
conn.password = 'secret'
conn.host = HOST
conn.extra = json.dumps(
Expand All @@ -671,6 +722,16 @@ def test_submit_run(self, mock_requests):
data = {'notebook_task': NOTEBOOK_TASK, 'new_cluster': NEW_CLUSTER}
run_id = self.hook.submit_run(data)

ad_call_args = mock_requests.method_calls[0]
assert ad_call_args[1][0] == AZURE_TOKEN_SERVICE_URL.format(AZURE_DEFAULT_AD_ENDPOINT, self.tenant_id)
assert ad_call_args[2]['data']['client_id'] == self.client_id
assert ad_call_args[2]['data']['resource'] == AZURE_MANAGEMENT_ENDPOINT

ad_call_args = mock_requests.method_calls[1]
assert ad_call_args[1][0] == AZURE_TOKEN_SERVICE_URL.format(AZURE_DEFAULT_AD_ENDPOINT, self.tenant_id)
assert ad_call_args[2]['data']['client_id'] == self.client_id
assert ad_call_args[2]['data']['resource'] == DEFAULT_DATABRICKS_SCOPE

assert run_id == '1'
args = mock_requests.post.call_args
kwargs = args[1]
Expand Down