Skip to content

Commit

Permalink
Fix redshift_data hook to consider access and secret key from the c…
Browse files Browse the repository at this point in the history
…onnections object not from extra (#746)

* Fix to considered aws access and secret key via login

* Add session token args

* Fix mypy issue

Fix test case
  • Loading branch information
bharanidharan14 committed Nov 4, 2022
1 parent 7923acd commit f4b5db7
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 0 deletions.
9 changes: 9 additions & 0 deletions astronomer/providers/amazon/aws/hooks/redshift_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ def get_conn_params(self) -> Dict[str, Union[str, int]]:
if "secret_access_key" in extra_config
else extra_config["aws_secret_access_key"]
)
elif connection_object.login:
conn_params["aws_access_key_id"] = connection_object.login
conn_params["aws_secret_access_key"] = connection_object.password
else:
raise AirflowException("Required access_key_id, aws_secret_access_key")

Expand All @@ -88,6 +91,12 @@ def get_conn_params(self) -> Dict[str, Union[str, int]]:
else:
raise AirflowException("Required Region name is missing !")

if "aws_session_token" in extra_config:
self.log.info(
"session token retrieved from extra, please note you are responsible for renewing these.",
)
conn_params["aws_session_token"] = extra_config["aws_session_token"]

if "cluster_identifier" in extra_config:
self.log.info("Retrieving cluster_identifier from Connection.extra_config['cluster_identifier']")
conn_params["cluster_identifier"] = extra_config["cluster_identifier"]
Expand Down
41 changes: 41 additions & 0 deletions tests/amazon/aws/hooks/test_redshift_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,47 @@ def test_get_conn_params(mock_get_connection, connection_details, expected_outpu
assert response == expected_output


@pytest.mark.parametrize(
"mock_login, mock_pwd, connection_details, expected_output",
[
(
"test",
"test",
{
"db_user": "test_user",
"cluster_identifier": "test_cluster",
"region": "us-east-2",
"database": "test-redshift_database",
"aws_session_token": "test",
},
{
"aws_access_key_id": "test",
"aws_secret_access_key": "test",
"aws_session_token": "test",
"db_user": "test_user",
"cluster_identifier": "test_cluster",
"region_name": "us-east-2",
"database": "test-redshift_database",
},
),
],
)
@mock.patch("astronomer.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.get_connection")
def test_get_conn_params_with_login_pwd(
mock_get_connection, mock_login, mock_pwd, connection_details, expected_output
):
"""
Test get_conn_params by mocking the AWS secret and access key and session token,
passing access and secret key in connection login and password instead passing in extra
"""
mock_conn = Connection(login=mock_login, password=mock_pwd, extra=json.dumps(connection_details))
mock_get_connection.return_value = mock_conn

hook = RedshiftDataHook(client_type="redshift-data")
response = hook.get_conn_params()
assert response == expected_output


@pytest.mark.parametrize(
"connection_details, test",
[
Expand Down

0 comments on commit f4b5db7

Please sign in to comment.