Skip to content

Commit

Permalink
Use only public AwsHook's methods during IAM authorization (#25424)
Browse files Browse the repository at this point in the history
  • Loading branch information
Taragolis committed Aug 2, 2022
1 parent d290002 commit 4eb0a41
Show file tree
Hide file tree
Showing 4 changed files with 163 additions and 57 deletions.
37 changes: 19 additions & 18 deletions airflow/providers/postgres/hooks/postgres.py
Expand Up @@ -174,37 +174,38 @@ def get_iam_token(self, conn: Connection) -> Tuple[str, str, int]:
or Redshift. Port is required. If none is provided, default is used for
each service
"""
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
try:
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
except ImportError:
from airflow.exceptions import AirflowException

raise AirflowException(
"apache-airflow-providers-amazon not installed, run: "
"pip install 'apache-airflow-providers-postgres[amazon]'."
)

redshift = conn.extra_dejson.get('redshift', False)
aws_conn_id = conn.extra_dejson.get('aws_conn_id', 'aws_default')
aws_hook = AwsBaseHook(aws_conn_id, client_type='rds')
login = conn.login
if conn.port is None:
port = 5439 if redshift else 5432
else:
port = conn.port
if redshift:
if conn.extra_dejson.get('redshift', False):
port = conn.port or 5439
# Pull the custer-identifier from the beginning of the Redshift URL
# ex. my-cluster.ccdre4hpd39h.us-east-1.redshift.amazonaws.com returns my-cluster
cluster_identifier = conn.extra_dejson.get('cluster-identifier', conn.host.split('.')[0])
session, endpoint_url = aws_hook._get_credentials(region_name=None)
client = session.client(
"redshift",
endpoint_url=endpoint_url,
config=aws_hook.config,
verify=aws_hook.verify,
)
cluster_creds = client.get_cluster_credentials(
DbUser=conn.login,
redshift_client = AwsBaseHook(aws_conn_id=aws_conn_id, client_type="redshift").conn
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/redshift.html#Redshift.Client.get_cluster_credentials
cluster_creds = redshift_client.get_cluster_credentials(
DbUser=login,
DbName=self.schema or conn.schema,
ClusterIdentifier=cluster_identifier,
AutoCreate=False,
)
token = cluster_creds['DbPassword']
login = cluster_creds['DbUser']
else:
token = aws_hook.conn.generate_db_auth_token(conn.host, port, conn.login)
port = conn.port or 5432
rds_client = AwsBaseHook(aws_conn_id=aws_conn_id, client_type="rds").conn
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/rds.html#RDS.Client.generate_db_auth_token
token = rds_client.generate_db_auth_token(conn.host, port, conn.login)
return login, token, port

def get_table_primary_key(self, table: str, schema: Optional[str] = "public") -> Optional[List[str]]:
Expand Down
6 changes: 5 additions & 1 deletion airflow/providers/postgres/provider.yaml
Expand Up @@ -60,7 +60,11 @@ hooks:
python-modules:
- airflow.providers.postgres.hooks.postgres


connection-types:
- hook-class-name: airflow.providers.postgres.hooks.postgres.PostgresHook
connection-type: postgres

additional-extras:
- name: amazon
dependencies:
- apache-airflow-providers-amazon>=2.6.0
34 changes: 34 additions & 0 deletions docs/apache-airflow-providers-postgres/connections/postgres.rst
Expand Up @@ -74,6 +74,40 @@ Extra (optional)
"sslkey": "/tmp/client-key.pem"
}
The following extra parameters use for additional Hook configuration:

* ``iam`` - If set to ``True`` than use AWS IAM database authentication for
`Amazon RDS <https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/UsingWithRDS.IAMDBAuth.html>`__,
`Amazon Aurora <https://docs.aws.amazon.com/AmazonRDS/latest/AuroraUserGuide/UsingWithRDS.IAMDBAuth.html>`__
or `Amazon Redshift <https://docs.aws.amazon.com/redshift/latest/mgmt/generating-user-credentials.html>`__.
* ``aws_conn_id`` - AWS Connection ID which use for authentication via AWS IAM,
if not specified then **aws_conn_id** is used.
* ``redshift`` - Used when AWS IAM database authentication enabled.
If set to ``True`` than authenticate to Amazon Redshift Cluster, otherwise to Amazon RDS or Amazon Aurora.
* ``cluster-identifier`` - The unique identifier of the Amazon Redshift Cluster that contains the database
for which you are requesting credentials. This parameter is case sensitive.
If not specified than hostname from **Connection Host** is used.

Example "extras" field (Amazon RDS PostgreSQL or Amazon Aurora PostgreSQL):

.. code-block:: json
{
"iam": true,
"aws_conn_id": "aws_awesome_rds_conn"
}
Example "extras" field (Amazon Redshift):

.. code-block:: json
{
"iam": true,
"aws_conn_id": "aws_awesome_redshift_conn",
"redshift": "/tmp/server-ca.pem",
"cluster-identifier": "awesome-redshift-identifier"
}
When specifying the connection as URI (in :envvar:`AIRFLOW_CONN_{CONN_ID}` variable) you should specify it
following the standard syntax of DB connections, where extras are passed as parameters
of the URI (note that all components of the URI should be URL-encoded).
Expand Down
143 changes: 105 additions & 38 deletions tests/providers/postgres/hooks/test_postgres.py
Expand Up @@ -26,12 +26,12 @@

from airflow.models import Connection
from airflow.providers.postgres.hooks.postgres import PostgresHook
from airflow.utils.types import NOTSET


class TestPostgresHookConn(unittest.TestCase):
def setUp(self):
super().setUp()

class TestPostgresHookConn:
@pytest.fixture(autouse=True)
def setup(self):
self.connection = Connection(login='login', password='password', host='host', schema='schema')

class UnitTestPostgresHook(PostgresHook):
Expand Down Expand Up @@ -63,10 +63,7 @@ def test_get_uri(self, mock_connect):
self.connection.conn_type = 'postgres'
self.db_hook.get_conn()
assert mock_connect.call_count == 1

self.assertEqual(
self.db_hook.get_uri(), "postgresql://login:password@host/schema?client_encoding=utf-8"
)
assert self.db_hook.get_uri() == "postgresql://login:password@host/schema?client_encoding=utf-8"

@mock.patch('airflow.providers.postgres.hooks.postgres.psycopg2.connect')
def test_get_conn_cursor(self, mock_connect):
Expand Down Expand Up @@ -106,13 +103,41 @@ def test_get_conn_from_connection_with_schema(self, mock_connect):
)

@mock.patch('airflow.providers.postgres.hooks.postgres.psycopg2.connect')
@mock.patch('airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook.get_client_type')
def test_get_conn_rds_iam_postgres(self, mock_client, mock_connect):
self.connection.extra = '{"iam":true}'
mock_client.return_value.generate_db_auth_token.return_value = 'aws_token'
@mock.patch('airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook')
@pytest.mark.parametrize("aws_conn_id", [NOTSET, None, "mock_aws_conn"])
@pytest.mark.parametrize("port", [65432, 5432, None])
def test_get_conn_rds_iam_postgres(self, mock_aws_hook_class, mock_connect, aws_conn_id, port):
mock_conn_extra = {"iam": True}
if aws_conn_id is not NOTSET:
mock_conn_extra["aws_conn_id"] = aws_conn_id
self.connection.extra = json.dumps(mock_conn_extra)
self.connection.port = port
mock_db_token = "aws_token"

# Mock AWS Connection
mock_aws_hook_instance = mock_aws_hook_class.return_value
mock_client = mock.MagicMock()
mock_client.generate_db_auth_token.return_value = mock_db_token
type(mock_aws_hook_instance).conn = mock.PropertyMock(return_value=mock_client)

self.db_hook.get_conn()
# Check AwsHook initialization
mock_aws_hook_class.assert_called_once_with(
# If aws_conn_id not set than fallback to aws_default
aws_conn_id=aws_conn_id if aws_conn_id is not NOTSET else "aws_default",
client_type="rds",
)
# Check boto3 'rds' client method `generate_db_auth_token` call args
mock_client.generate_db_auth_token.assert_called_once_with(
self.connection.host, (port or 5432), self.connection.login
)
# Check expected psycopg2 connection call args
mock_connect.assert_called_once_with(
user='login', password='aws_token', host='host', dbname='schema', port=5432
user=self.connection.login,
password=mock_db_token,
host=self.connection.host,
dbname=self.connection.schema,
port=(port or 5432),
)

@mock.patch('airflow.providers.postgres.hooks.postgres.psycopg2.connect')
Expand All @@ -124,39 +149,81 @@ def test_get_conn_extra(self, mock_connect):
)

@mock.patch('airflow.providers.postgres.hooks.postgres.psycopg2.connect')
def test_get_conn_rds_iam_redshift(self, mock_connect):
self.connection.extra = '{"iam":true, "redshift":true, "cluster-identifier": "different-identifier"}'
self.connection.host = 'cluster-identifier.ccdfre4hpd39h.us-east-1.redshift.amazonaws.com'
login = f'IAM:{self.connection.login}'

mock_session = mock.Mock()
mock_get_cluster_credentials = mock_session.client.return_value.get_cluster_credentials
mock_get_cluster_credentials.return_value = {'DbPassword': 'aws_token', 'DbUser': login}

aws_get_credentials_patcher = mock.patch(
"airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook._get_credentials",
return_value=(mock_session, None),
@mock.patch('airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook')
@pytest.mark.parametrize("aws_conn_id", [NOTSET, None, "mock_aws_conn"])
@pytest.mark.parametrize("port", [5432, 5439, None])
@pytest.mark.parametrize(
"host,conn_cluster_identifier,expected_cluster_identifier",
[
(
'cluster-identifier.ccdfre4hpd39h.us-east-1.redshift.amazonaws.com',
NOTSET,
'cluster-identifier',
),
(
'cluster-identifier.ccdfre4hpd39h.us-east-1.redshift.amazonaws.com',
'different-identifier',
'different-identifier',
),
],
)
def test_get_conn_rds_iam_redshift(
self,
mock_aws_hook_class,
mock_connect,
aws_conn_id,
port,
host,
conn_cluster_identifier,
expected_cluster_identifier,
):
mock_conn_extra = {
"iam": True,
"redshift": True,
}
if aws_conn_id is not NOTSET:
mock_conn_extra["aws_conn_id"] = aws_conn_id
if conn_cluster_identifier is not NOTSET:
mock_conn_extra["cluster-identifier"] = conn_cluster_identifier

self.connection.extra = json.dumps(mock_conn_extra)
self.connection.host = host
self.connection.port = port
mock_db_user = f'IAM:{self.connection.login}'
mock_db_pass = "aws_token"

# Mock AWS Connection
mock_aws_hook_instance = mock_aws_hook_class.return_value
mock_client = mock.MagicMock()
mock_client.get_cluster_credentials.return_value = {
'DbPassword': mock_db_pass,
'DbUser': mock_db_user,
}
type(mock_aws_hook_instance).conn = mock.PropertyMock(return_value=mock_client)

self.db_hook.get_conn()
# Check AwsHook initialization
mock_aws_hook_class.assert_called_once_with(
# If aws_conn_id not set than fallback to aws_default
aws_conn_id=aws_conn_id if aws_conn_id is not NOTSET else "aws_default",
client_type="redshift",
)
get_cluster_credentials_call = mock.call(
# Check boto3 'redshift' client method `get_cluster_credentials` call args
mock_client.get_cluster_credentials.assert_called_once_with(
DbUser=self.connection.login,
DbName=self.connection.schema,
ClusterIdentifier="different-identifier",
ClusterIdentifier=expected_cluster_identifier,
AutoCreate=False,
)

with aws_get_credentials_patcher:
self.db_hook.get_conn()
assert mock_get_cluster_credentials.mock_calls == [get_cluster_credentials_call]
# Check expected psycopg2 connection call args
mock_connect.assert_called_once_with(
user=login, password='aws_token', host=self.connection.host, dbname='schema', port=5439
user=mock_db_user,
password=mock_db_pass,
host=host,
dbname=self.connection.schema,
port=(port or 5439),
)

# Verify that the connection object has not been mutated.
mock_get_cluster_credentials.reset_mock()
with aws_get_credentials_patcher:
self.db_hook.get_conn()
assert mock_get_cluster_credentials.mock_calls == [get_cluster_credentials_call]

def test_get_uri_from_connection_without_schema_override(self):
self.db_hook.get_connection = mock.MagicMock(
return_value=Connection(
Expand Down

0 comments on commit 4eb0a41

Please sign in to comment.