Skip to content

Commit

Permalink
Add endpoint_url in test_connection (#32664)
Browse files Browse the repository at this point in the history
  • Loading branch information
ieunea1128 committed Jul 24, 2023
1 parent 8012c9f commit 282854b
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 2 deletions.
6 changes: 5 additions & 1 deletion airflow/providers/amazon/aws/hooks/base_aws.py
Expand Up @@ -809,7 +809,11 @@ def test_connection(self):
"""
try:
session = self.get_session()
conn_info = session.client("sts").get_caller_identity()
test_endpoint_url = self.conn_config.extra_config.get("test_endpoint_url")
conn_info = session.client(
"sts",
endpoint_url=test_endpoint_url,
).get_caller_identity()
metadata = conn_info.pop("ResponseMetadata", {})
if metadata.get("HTTPStatusCode") != 200:
try:
Expand Down
20 changes: 19 additions & 1 deletion tests/providers/amazon/aws/hooks/test_base_aws.py
Expand Up @@ -295,6 +295,7 @@ def test_get_credentials_from_role_arn(self, conn_id, conn_extra, region_name):
conn = AwsConnectionWrapper.from_connection_metadata(conn_id=conn_id, extra=extra)
sf = BaseSessionFactory(conn=conn)
session = sf.create_session()

assert session.region_name == region_name
# Validate method of botocore credentials provider.
# It shouldn't be 'explicit' which refers in this case to initial credentials.
Expand Down Expand Up @@ -858,7 +859,7 @@ def test_hook_connection_test_failed(self, mock_boto3_session):
result, message = hook.test_connection()
assert not result
assert message == json.dumps(response_metadata)
mock_sts_client.assert_called_once_with("sts")
mock_sts_client.assert_called_once_with("sts", endpoint_url=None)

def mock_error():
raise ConnectionError("Test Error")
Expand All @@ -872,6 +873,23 @@ def mock_error():

assert hook.client_type == "ec2"

@mock_sts
@pytest.mark.parametrize(
"test_endpoint_url, result_url",
[
(None, "https://sts.amazonaws.com"),
("https://sts.us-east-1.amazonaws.com", "https://sts.us-east-1.amazonaws.com"),
],
)
def test_hook_connection_endpoint_url_valid(self, test_endpoint_url, result_url):
"""Test if test_endpoint_url is valid in test connection"""
conn = AwsConnectionWrapper.from_connection_metadata(conn_id=MOCK_AWS_CONN_ID)
sf = BaseSessionFactory(conn=conn)
session = sf.create_session()
client = session.client("sts", endpoint_url=test_endpoint_url)

assert client._endpoint.host == result_url

@mock.patch.dict(os.environ, {f"AIRFLOW_CONN_{MOCK_AWS_CONN_ID.upper()}": "aws://"})
def test_conn_config_conn_id_exists(self):
"""Test retrieve connection config if aws_conn_id exists."""
Expand Down

0 comments on commit 282854b

Please sign in to comment.