
Commit

fix: Pass proxies config when using ClientSecretCredential in AzureDataLakeStorageV2Hook (apache#37103)

* fix: Pass proxies config when using ClientSecretCredential in AzureDataLakeStorageV2Hook and added 

---------

Co-authored-by: David Blain <david.blain@b-holding.be>
Co-authored-by: David Blain <david.blain@infrabel.be>
3 people authored and abhishekbhakat committed Mar 5, 2024
1 parent e3f02b3 commit 7220432
Showing 3 changed files with 55 additions and 3 deletions.
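
For context, the connection fields the hook reads (tenant_id, proxies, account_url) appear in the diff below. A minimal usage sketch, assuming Airflow's JSON connection format for AIRFLOW_CONN_* environment variables and an illustrative connection id adls_fabric; none of this is part of the commit:

# Sketch only: connection id, client credentials, proxy URL and account URL are placeholders.
import json
import os

from airflow.providers.microsoft.azure.hooks.data_lake import AzureDataLakeStorageV2Hook

os.environ["AIRFLOW_CONN_ADLS_FABRIC"] = json.dumps(
    {
        "conn_type": "adls",
        "login": "my-client-id",
        "password": "my-client-secret",
        "extra": {
            "tenant_id": "my-tenant-id",
            "proxies": {"https": "https://proxy:80"},
            "account_url": "https://onelake.dfs.fabric.microsoft.com",
        },
    }
)

hook = AzureDataLakeStorageV2Hook(adls_conn_id="adls_fabric")
client = hook.get_conn()  # DataLakeServiceClient built with a proxied ClientSecretCredential
print(client.url)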
12 changes: 10 additions & 2 deletions airflow/providers/microsoft/azure/hooks/data_lake.py
@@ -348,7 +348,11 @@ def get_conn(self) -> DataLakeServiceClient:  # type: ignore[override]
             # use Active Directory auth
             app_id = conn.login
             app_secret = conn.password
-            credential = ClientSecretCredential(tenant, app_id, app_secret)
+            proxies = extra.get("proxies", {})
+
+            credential = ClientSecretCredential(
+                tenant_id=tenant, client_id=app_id, client_secret=app_secret, proxies=proxies
+            )
         elif conn.password:
             credential = conn.password
         else:
@@ -359,8 +363,12 @@ def get_conn(self) -> DataLakeServiceClient:  # type: ignore[override]
                 workload_identity_tenant_id=workload_identity_tenant_id,
             )

+        account_url = extra.pop("account_url", f"https://{conn.host}.dfs.core.windows.net")
+
+        self.log.info("account_url: %s", account_url)
+
         return DataLakeServiceClient(
-            account_url=f"https://{conn.host}.dfs.core.windows.net",
+            account_url=account_url,
             credential=credential,  # type: ignore[arg-type]
             **extra,
         )
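
A minimal sketch of the mechanism the hook change relies on, assuming azure-identity forwards extra keyword arguments such as proxies into its azure-core pipeline, where they configure a ProxyPolicy (this is what the new test below asserts); the ids and proxy URL are placeholders:

from azure.identity import ClientSecretCredential

# Sketch only: token requests issued by this credential should be routed
# through the configured HTTPS proxy.
credential = ClientSecretCredential(
    tenant_id="my-tenant-id",
    client_id="my-client-id",
    client_secret="my-client-secret",
    proxies={"https": "https://proxy:80"},
)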
39 changes: 39 additions & 0 deletions tests/providers/microsoft/azure/hooks/test_data_lake.py
@@ -17,15 +17,20 @@
 # under the License.
 from __future__ import annotations

+from typing import TYPE_CHECKING
 from unittest import mock
 from unittest.mock import PropertyMock

 import pytest
+from azure.core.pipeline.policies._universal import ProxyPolicy
 from azure.storage.filedatalake._models import FileSystemProperties

 from airflow.models import Connection
 from airflow.providers.microsoft.azure.hooks.data_lake import AzureDataLakeStorageV2Hook

+if TYPE_CHECKING:
+    from azure.storage.filedatalake import DataLakeServiceClient
+
 MODULE = "airflow.providers.microsoft.azure.hooks.data_lake"

@@ -297,3 +302,37 @@ def test_connection_failure(self, mock_conn):

         assert status is False
         assert msg == "Authentication failed."
+
+    @mock.patch(f"{MODULE}.AzureDataLakeStorageV2Hook.get_connection")
+    def test_proxies_passed_to_credentials(self, mock_conn):
+        hook = AzureDataLakeStorageV2Hook(adls_conn_id=self.conn_id)
+        mock_conn.return_value = Connection(
+            conn_id=self.conn_id,
+            login="client_id",
+            password="secret",
+            extra={
+                "tenant_id": "tenant-id",
+                "proxies": {"https": "https://proxy:80"},
+                "account_url": "https://onelake.dfs.fabric.microsoft.com",
+            },
+        )
+        conn: DataLakeServiceClient = hook.get_conn()
+
+        assert conn is not None
+        assert conn.primary_endpoint == "https://onelake.dfs.fabric.microsoft.com/"
+        assert conn.primary_hostname == "onelake.dfs.fabric.microsoft.com"
+        assert conn.scheme == "https"
+        assert conn.url == "https://onelake.dfs.fabric.microsoft.com/"
+        assert conn.credential._client_id == "client_id"
+        assert conn.credential._client_credential == "secret"
+        assert self.find_policy(conn, ProxyPolicy) is not None
+        assert self.find_policy(conn, ProxyPolicy).proxies["https"] == "https://proxy:80"
+
+    def find_policy(self, conn, policy_type):
+        policies = conn.credential._client._pipeline._impl_policies
+        return next(
+            map(
+                lambda policy: policy._policy,
+                filter(lambda policy: isinstance(policy._policy, policy_type), policies),
+            )
+        )
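
Outside the test suite, the same wiring can be spot-checked with a helper similar to find_policy above. A debugging sketch, assuming the private attribute chain credential._client._pipeline._impl_policies used by the test stays available; it is not public API:

from azure.core.pipeline.policies._universal import ProxyPolicy


def proxy_policy_of(client):
    # Walk the credential's internal azure-core pipeline and return its ProxyPolicy, if any.
    policies = client.credential._client._pipeline._impl_policies
    for wrapper in policies:
        if isinstance(wrapper._policy, ProxyPolicy):
            return wrapper._policy
    return None


# Usage, given a DataLakeServiceClient returned by AzureDataLakeStorageV2Hook.get_conn():
#   policy = proxy_policy_of(client)
#   print(policy.proxies if policy else "no proxy configured")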
7 changes: 6 additions & 1 deletion tests/providers/microsoft/azure/transfers/test_local_to_adls.py
@@ -17,17 +17,22 @@
 # under the License.
 from __future__ import annotations

+import json
 from unittest import mock

 import pytest

 from airflow.exceptions import AirflowException
-from airflow.providers.microsoft.azure.transfers.local_to_adls import LocalFilesystemToADLSOperator
+from airflow.providers.microsoft.azure.transfers.local_to_adls import (
+    LocalFilesystemToADLSOperator,
+)

 TASK_ID = "test-adls-upload-operator"
+FILE_SYSTEM_NAME = "Fabric"
 LOCAL_PATH = "test/*"
 BAD_LOCAL_PATH = "test/**"
 REMOTE_PATH = "TEST-DIR"
+DATA = json.dumps({"name": "David", "surname": "Blain", "gender": "M"}).encode("utf-8")


 class TestADLSUploadOperator:
