Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -483,9 +483,13 @@ def get_aws_iam_token(self, conn: Connection) -> tuple[str, str, int]:
port = conn.port or 5439
# Pull the custer-identifier from the beginning of the Redshift URL
# ex. my-cluster.ccdre4hpd39h.us-east-1.redshift.amazonaws.com returns my-cluster
cluster_identifier = conn.extra_dejson.get(
"cluster-identifier", cast("str", conn.host).split(".")[0]
)
cluster_identifier = conn.extra_dejson.get("cluster-identifier")
if cluster_identifier is None:
if not conn.host:
raise ValueError(
"connection host is required for AWS IAM token when cluster-identifier is not set in extras."
)
cluster_identifier = conn.host.split(".")[0]
redshift_client = AwsBaseHook(aws_conn_id=aws_conn_id, client_type="redshift").conn
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/redshift/client/get_cluster_credentials.html#Redshift.Client.get_cluster_credentials
cluster_creds = redshift_client.get_cluster_credentials(
Expand All @@ -501,7 +505,13 @@ def get_aws_iam_token(self, conn: Connection) -> tuple[str, str, int]:
# Pull the workgroup-name from the query params/extras, if not there then pull it from the
# beginning of the Redshift URL
# ex. workgroup-name.ccdre4hpd39h.us-east-1.redshift.amazonaws.com returns workgroup-name
workgroup_name = conn.extra_dejson.get("workgroup-name", cast("str", conn.host).split(".")[0])
workgroup_name = conn.extra_dejson.get("workgroup-name")
if workgroup_name is None:
if not conn.host:
raise ValueError(
"connection host is required for AWS IAM token when workgroup-name is not set in extras."
)
workgroup_name = conn.host.split(".")[0]
redshift_serverless_client = AwsBaseHook(
aws_conn_id=aws_conn_id, client_type="redshift-serverless"
).conn
Expand Down
108 changes: 62 additions & 46 deletions providers/postgres/tests/unit/postgres/hooks/test_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,18 +304,22 @@ def test_get_conn_extra(self, mock_connect):
@pytest.mark.parametrize("aws_conn_id", [NOTSET, None, "mock_aws_conn"])
@pytest.mark.parametrize("port", [5432, 5439, None])
@pytest.mark.parametrize(
("host", "conn_cluster_identifier", "expected_cluster_identifier"),
("host", "conn_cluster_identifier", "expected_cluster_identifier", "raises_exception"),
[
(
"cluster-identifier.ccdfre4hpd39h.us-east-1.redshift.amazonaws.com",
NOTSET,
"cluster-identifier",
False,
),
(
"cluster-identifier.ccdfre4hpd39h.us-east-1.redshift.amazonaws.com",
"different-identifier",
"different-identifier",
False,
),
(None, NOTSET, None, True),
(None, "cluster-identifier", "cluster-identifier", False),
],
)
def test_get_conn_rds_iam_redshift(
Expand All @@ -327,6 +331,7 @@ def test_get_conn_rds_iam_redshift(
host,
conn_cluster_identifier,
expected_cluster_identifier,
raises_exception,
):
mock_aws_hook_class = mocker.patch("airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook")

Expand All @@ -353,45 +358,52 @@ def test_get_conn_rds_iam_redshift(
"DbUser": mock_db_user,
}
type(mock_aws_hook_instance).conn = mocker.PropertyMock(return_value=mock_client)

self.db_hook.get_conn()
# Check AwsHook initialization
mock_aws_hook_class.assert_called_once_with(
# If aws_conn_id not set than fallback to aws_default
aws_conn_id=aws_conn_id if aws_conn_id is not NOTSET else "aws_default",
client_type="redshift",
)
# Check boto3 'redshift' client method `get_cluster_credentials` call args
mock_client.get_cluster_credentials.assert_called_once_with(
DbUser=self.connection.login,
DbName=self.connection.schema,
ClusterIdentifier=expected_cluster_identifier,
AutoCreate=False,
)
# Check expected psycopg2 connection call args
mock_connect.assert_called_once_with(
user=mock_db_user,
password=mock_db_pass,
host=host,
dbname=self.connection.schema,
port=(port or 5439),
)
if raises_exception:
with pytest.raises(ValueError, match="connection host is required"):
self.db_hook.get_conn()
else:
self.db_hook.get_conn()
# Check AwsHook initialization
mock_aws_hook_class.assert_called_once_with(
# If aws_conn_id not set than fallback to aws_default
aws_conn_id=aws_conn_id if aws_conn_id is not NOTSET else "aws_default",
client_type="redshift",
)
# Check boto3 'redshift' client method `get_cluster_credentials` call args
mock_client.get_cluster_credentials.assert_called_once_with(
DbUser=self.connection.login,
DbName=self.connection.schema,
ClusterIdentifier=expected_cluster_identifier,
AutoCreate=False,
)
# Check expected psycopg2 connection call args
mock_connect.assert_called_once_with(
user=mock_db_user,
password=mock_db_pass,
host=host,
dbname=self.connection.schema,
port=(port or 5439),
)

@pytest.mark.parametrize("aws_conn_id", [NOTSET, None, "mock_aws_conn"])
@pytest.mark.parametrize("port", [5432, 5439, None])
@pytest.mark.parametrize(
("host", "conn_workgroup_name", "expected_workgroup_name"),
("host", "conn_workgroup_name", "expected_workgroup_name", "raises_exception"),
[
(
"serverless-workgroup.ccdfre4hpd39h.us-east-1.redshift.amazonaws.com",
NOTSET,
"serverless-workgroup",
False,
),
(
"cluster-identifier.ccdfre4hpd39h.us-east-1.redshift.amazonaws.com",
"different-workgroup",
"different-workgroup",
False,
),
(None, NOTSET, None, True),
(None, "serverless-workgroup", "serverless-workgroup", False),
],
)
def test_get_conn_rds_iam_redshift_serverless(
Expand All @@ -403,6 +415,7 @@ def test_get_conn_rds_iam_redshift_serverless(
host,
conn_workgroup_name,
expected_workgroup_name,
raises_exception,
):
mock_aws_hook_class = mocker.patch("airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook")

Expand All @@ -429,27 +442,30 @@ def test_get_conn_rds_iam_redshift_serverless(
"dbUser": mock_db_user,
}
type(mock_aws_hook_instance).conn = mocker.PropertyMock(return_value=mock_client)

self.db_hook.get_conn()
# Check AwsHook initialization
mock_aws_hook_class.assert_called_once_with(
# If aws_conn_id not set than fallback to aws_default
aws_conn_id=aws_conn_id if aws_conn_id is not NOTSET else "aws_default",
client_type="redshift-serverless",
)
# Check boto3 'redshift' client method `get_cluster_credentials` call args
mock_client.get_credentials.assert_called_once_with(
dbName=self.connection.schema,
workgroupName=expected_workgroup_name,
)
# Check expected psycopg2 connection call args
mock_connect.assert_called_once_with(
user=mock_db_user,
password=mock_db_pass,
host=host,
dbname=self.connection.schema,
port=(port or 5439),
)
if raises_exception:
with pytest.raises(ValueError, match="connection host is required"):
self.db_hook.get_conn()
else:
self.db_hook.get_conn()
# Check AwsHook initialization
mock_aws_hook_class.assert_called_once_with(
# If aws_conn_id not set than fallback to aws_default
aws_conn_id=aws_conn_id if aws_conn_id is not NOTSET else "aws_default",
client_type="redshift-serverless",
)
# Check boto3 'redshift' client method `get_cluster_credentials` call args
mock_client.get_credentials.assert_called_once_with(
dbName=self.connection.schema,
workgroupName=expected_workgroup_name,
)
# Check expected psycopg2 connection call args
mock_connect.assert_called_once_with(
user=mock_db_user,
password=mock_db_pass,
host=host,
dbname=self.connection.schema,
port=(port or 5439),
)

def test_get_conn_azure_iam(self, mocker, mock_connect):
mock_azure_conn_id = "azure_conn1"
Expand Down