Skip to content

Commit

Permalink
Added sas_token var to BlobServiceClient return. Updated tests (#19234)
Browse files Browse the repository at this point in the history
  • Loading branch information
ReadytoRocc committed Oct 27, 2021
1 parent 3c08c02 commit 61d0093
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 4 deletions.
2 changes: 1 addition & 1 deletion airflow/providers/microsoft/azure/hooks/wasb.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def get_conn(self) -> BlobServiceClient:
return BlobServiceClient(account_url=conn.host, credential=token_credential)
sas_token = extra.get('sas_token') or extra.get('extra__wasb__sas_token')
if sas_token and sas_token.startswith('https'):
return BlobServiceClient(account_url=extra.get('sas_token'))
return BlobServiceClient(account_url=sas_token)
if sas_token and not sas_token.startswith('https'):
return BlobServiceClient(account_url=f"https://{conn.login}.blob.core.windows.net/" + sas_token)

Expand Down
44 changes: 41 additions & 3 deletions tests/providers/microsoft/azure/hooks/test_wasb.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import pytest
from azure.identity import ManagedIdentityCredential
from azure.storage.blob import BlobServiceClient
from parameterized import parameterized

from airflow.exceptions import AirflowException
from airflow.models import Connection
Expand All @@ -47,6 +48,9 @@ def setUp(self):
self.shared_key_conn_id = 'azure_shared_key_test'
self.ad_conn_id = 'azure_AD_test'
self.sas_conn_id = 'sas_token_id'
self.extra__wasb__sas_conn_id = 'extra__sas_token_id'
self.http_sas_conn_id = 'http_sas_token_id'
self.extra__wasb__http_sas_conn_id = 'extra__http_sas_token_id'
self.public_read_conn_id = 'pub_read_id'
self.managed_identity_conn_id = 'managed_identity'

Expand Down Expand Up @@ -95,6 +99,27 @@ def setUp(self):
extra=json.dumps({'sas_token': 'token'}),
)
)
db.merge_conn(
Connection(
conn_id=self.extra__wasb__sas_conn_id,
conn_type=self.connection_type,
extra=json.dumps({'extra__wasb__sas_token': 'token'}),
)
)
db.merge_conn(
Connection(
conn_id=self.http_sas_conn_id,
conn_type=self.connection_type,
extra=json.dumps({'sas_token': 'https://login.blob.core.windows.net/token'}),
)
)
db.merge_conn(
Connection(
conn_id=self.extra__wasb__http_sas_conn_id,
conn_type=self.connection_type,
extra=json.dumps({'extra__wasb__sas_token': 'https://login.blob.core.windows.net/token'}),
)
)

def test_key(self):
hook = WasbHook(wasb_conn_id='wasb_test_key')
Expand All @@ -119,9 +144,22 @@ def test_managed_identity(self):
self.assertIsInstance(hook.get_conn(), BlobServiceClient)
self.assertIsInstance(hook.get_conn().credential, ManagedIdentityCredential)

def test_sas_token_connection(self):
hook = WasbHook(wasb_conn_id=self.sas_conn_id)
assert isinstance(hook.get_conn(), BlobServiceClient)
@parameterized.expand(
[
('sas_conn_id', 'sas_token'),
('extra__wasb__sas_conn_id', 'extra__wasb__sas_token'),
('http_sas_conn_id', 'sas_token'),
('extra__wasb__http_sas_conn_id', 'extra__wasb__sas_token'),
],
)
def test_sas_token_connection(self, conn_id_str, extra_key):
conn_id = self.__getattribute__(conn_id_str)
hook = WasbHook(wasb_conn_id=conn_id)
conn = hook.get_conn()
hook_conn = hook.get_connection(hook.conn_id)
sas_token = hook_conn.extra_dejson[extra_key]
assert isinstance(conn, BlobServiceClient)
assert conn.url.endswith(sas_token + '/')

@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
def test_check_for_blob(self, mock_service):
Expand Down

0 comments on commit 61d0093

Please sign in to comment.