Skip to content

Commit

Permalink
Athena profile mapping set aws_session_token in profile only if it ex…
Browse files Browse the repository at this point in the history
…ist (#1022)

Set `aws_session_token` in the Athena profile only if it exists to avoid passing an empty string as a token to the AWS API.

Closes: #962
  • Loading branch information
pankajastro authored and tatiana committed Jun 6, 2024
1 parent 36ce7ad commit df09ea9
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 8 deletions.
4 changes: 3 additions & 1 deletion cosmos/profiles/athena/access_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,11 @@ def profile(self) -> dict[str, Any | None]:
**self.profile_args,
"aws_access_key_id": self.temporary_credentials.access_key,
"aws_secret_access_key": self.get_env_var_format("aws_secret_access_key"),
"aws_session_token": self.get_env_var_format("aws_session_token"),
}

if self.temporary_credentials.token:
profile["aws_session_token"] = self.get_env_var_format("aws_session_token")

return self.filter_null(profile)

@property
Expand Down
51 changes: 44 additions & 7 deletions tests/profiles/athena/test_athena_access_key.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
"Tests for the Athena profile."
from __future__ import annotations

import json
import sys
from collections import namedtuple
from unittest import mock
from unittest.mock import MagicMock, patch

import pytest
Expand Down Expand Up @@ -39,28 +41,41 @@ def get_credentials(self) -> Credentials:
yield mock_aws_hook


@pytest.fixture()
def mock_athena_conn(): # type: ignore
"""
Sets the connection as an environment variable.
"""

def mock_conn_value(token: str | None = None) -> Connection:
conn = Connection(
conn_id="my_athena_connection",
conn_type="aws",
login="my_aws_access_key_id",
password="my_aws_secret_key",
extra=json.dumps(
{
"aws_session_token": "token123",
"aws_session_token": token,
"database": "my_database",
"region_name": "us-east-1",
"s3_staging_dir": "s3://my_bucket/dbt/",
"schema": "my_schema",
}
),
)
return conn


@pytest.fixture()
def mock_athena_conn(): # type: ignore
"""
Sets the connection as an environment variable.
"""
conn = mock_conn_value(token="token123")
with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
yield conn


@pytest.fixture()
def mock_athena_conn_without_token(): # type: ignore
"""
Sets the connection as an environment variable.
"""
conn = mock_conn_value(token=None)
with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
yield conn

Expand Down Expand Up @@ -151,6 +166,28 @@ def test_athena_profile_args(
}


@mock.patch("cosmos.profiles.athena.access_key.AthenaAccessKeyProfileMapping._get_temporary_credentials")
def test_athena_profile_args_without_token(mock_temp_cred, mock_athena_conn_without_token: Connection) -> None:
"""
Tests that the profile values get set correctly for Athena.
"""
ReadOnlyCredentials = namedtuple("ReadOnlyCredentials", ["access_key", "secret_key", "token"])
credentials = ReadOnlyCredentials(access_key="my_aws_access_key", secret_key="my_aws_secret_key", token=None)
mock_temp_cred.return_value = credentials

profile_mapping = get_automatic_profile_mapping(mock_athena_conn_without_token.conn_id)

assert profile_mapping.profile == {
"type": "athena",
"aws_access_key_id": "my_aws_access_key",
"aws_secret_access_key": "{{ env_var('COSMOS_CONN_AWS_AWS_SECRET_ACCESS_KEY') }}",
"database": mock_athena_conn_without_token.extra_dejson.get("database"),
"region_name": mock_athena_conn_without_token.extra_dejson.get("region_name"),
"s3_staging_dir": mock_athena_conn_without_token.extra_dejson.get("s3_staging_dir"),
"schema": mock_athena_conn_without_token.extra_dejson.get("schema"),
}


def test_athena_profile_args_overrides(
mock_athena_conn: Connection,
) -> None:
Expand Down

0 comments on commit df09ea9

Please sign in to comment.