Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix RedshiftDataHook to accept access and secret keys from the connections object also. #746

Merged
merged 3 commits into from
Nov 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
9 changes: 9 additions & 0 deletions astronomer/providers/amazon/aws/hooks/redshift_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ def get_conn_params(self) -> Dict[str, Union[str, int]]:
if "secret_access_key" in extra_config
else extra_config["aws_secret_access_key"]
)
elif connection_object.login:
conn_params["aws_access_key_id"] = connection_object.login
conn_params["aws_secret_access_key"] = connection_object.password
bharanidharan14 marked this conversation as resolved.
Show resolved Hide resolved
else:
raise AirflowException("Required access_key_id, aws_secret_access_key")

Expand All @@ -88,6 +91,12 @@ def get_conn_params(self) -> Dict[str, Union[str, int]]:
else:
raise AirflowException("Required Region name is missing !")

if "aws_session_token" in extra_config:
self.log.info(
"session token retrieved from extra, please note you are responsible for renewing these.",
)
conn_params["aws_session_token"] = extra_config["aws_session_token"]

if "cluster_identifier" in extra_config:
self.log.info("Retrieving cluster_identifier from Connection.extra_config['cluster_identifier']")
conn_params["cluster_identifier"] = extra_config["cluster_identifier"]
Expand Down
41 changes: 41 additions & 0 deletions tests/amazon/aws/hooks/test_redshift_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,47 @@ def test_get_conn_params(mock_get_connection, connection_details, expected_outpu
assert response == expected_output


@pytest.mark.parametrize(
"mock_login, mock_pwd, connection_details, expected_output",
[
(
"test",
"test",
{
"db_user": "test_user",
"cluster_identifier": "test_cluster",
"region": "us-east-2",
"database": "test-redshift_database",
"aws_session_token": "test",
},
{
"aws_access_key_id": "test",
"aws_secret_access_key": "test",
"aws_session_token": "test",
"db_user": "test_user",
"cluster_identifier": "test_cluster",
"region_name": "us-east-2",
"database": "test-redshift_database",
},
),
],
)
@mock.patch("astronomer.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.get_connection")
def test_get_conn_params_with_login_pwd(
mock_get_connection, mock_login, mock_pwd, connection_details, expected_output
):
"""
Test get_conn_params by mocking the AWS secret and access key and session token,
passing access and secret key in connection login and password instead passing in extra
"""
mock_conn = Connection(login=mock_login, password=mock_pwd, extra=json.dumps(connection_details))
mock_get_connection.return_value = mock_conn

hook = RedshiftDataHook(client_type="redshift-data")
response = hook.get_conn_params()
assert response == expected_output


@pytest.mark.parametrize(
"connection_details, test",
[
Expand Down