Pull the database name for the Postgres connection only from the connection definition #117

Merged · 10 commits · Oct 10, 2023
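Context for the change: Airflow's Postgres provider stores the database name in the connection's `schema` field, so the operator now takes the database segment of the SQLAlchemy URI only from the connection definition instead of falling back to the operator-level `schema` argument. A minimal sketch of a conforming connection definition (names and values illustrative, assuming Airflow's `AIRFLOW_CONN_*` URI convention):

```python
import os

# Illustrative only: register a Postgres connection via an environment variable.
# In an Airflow connection URI, the trailing path segment populates the schema
# field, which for Postgres holds the database name ("analytics" here).
os.environ["AIRFLOW_CONN_POSTGRES_DEFAULT"] = "postgres://user:pass@db.example.com:5432/analytics"
```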
14 changes: 12 additions & 2 deletions great_expectations_provider/operators/great_expectations.py
@@ -244,9 +244,9 @@ def make_connection_configuration(self) -> Dict[str, str]:
             raise ValueError(f"Connections does not exist in Airflow for conn_id: {self.conn_id}")
         self.schema = self.schema or self.conn.schema
         conn_type = self.conn.conn_type
-        if conn_type in ("redshift", "postgres", "mysql", "mssql"):
+        if conn_type in ("redshift", "mysql", "mssql"):
             odbc_connector = ""
-            if conn_type in ("redshift", "postgres"):
+            if conn_type == "redshift":
                 odbc_connector = "postgresql+psycopg2"
                 database_name = self.schema
             elif conn_type == "mysql":
@@ -263,6 +263,16 @@ def make_connection_configuration(self) -> Dict[str, str]:
                 f"{odbc_connector}://{self.conn.login}:{self.conn.password}@"
                 f"{self.conn.host}:{self.conn.port}/{database_name}{driver}"
             )
+        elif conn_type == "postgres":
+            # the schema parameter in the postgres connection is the database name
+            if self.conn.schema:
+                postgres_database = self.conn.schema
+                odbc_connector = "postgresql+psycopg2"
+                uri_string = f"{odbc_connector}://{self.conn.login}:{self.conn.password}@{self.conn.host}:{self.conn.port}/{postgres_database}"  # noqa
+            else:
+                raise ValueError(
+                    "Specify the name of the database in the schema parameter of the Postgres connection. See: https://airflow.apache.org/docs/apache-airflow-providers-postgres/stable/connections/postgres.html"  # noqa
+                )
         elif conn_type == "snowflake":
             try:
                 return self.build_snowflake_connection_config_from_hook()
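For illustration, the new branch reduces to the following sketch; `postgres_uri` is a hypothetical helper that mirrors the diff above and assumes Airflow's `Connection` model:

```python
from airflow.models import Connection

def postgres_uri(conn: Connection) -> str:
    # Mirror of the new postgres branch: the database name comes only from the
    # connection's schema field, never from the operator's own schema argument.
    if not conn.schema:
        raise ValueError("Specify the name of the database in the schema parameter of the Postgres connection.")
    return f"postgresql+psycopg2://{conn.login}:{conn.password}@{conn.host}:{conn.port}/{conn.schema}"

conn = Connection(
    conn_type="postgres",
    login="user",
    password="pass",
    host="db.example.com",
    port=5432,
    schema="analytics",  # for Postgres, this field holds the database name
)
assert postgres_uri(conn) == "postgresql+psycopg2://user:pass@db.example.com:5432/analytics"
```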
110 changes: 61 additions & 49 deletions tests/operators/test_great_expectations.py
@@ -12,6 +12,7 @@
 
 import logging
 import os
+import tempfile
 import unittest.mock as mock
 from pathlib import Path
 
@@ -859,57 +860,68 @@ def test_great_expectations_operator__make_connection_string_snowflake(mocker):
 
 
 def test_great_expectations_operator__make_connection_string_snowflake_pkey(mocker):
-    private_key_bytes = b"secret"
-    test_conn_conf = {
-        "url": URL.create(
-            drivername="snowflake",
-            username="user",
-            password="",
-            host="account.region-east-1",
-            database="database/schema",
-            query={"role": "role", "warehouse": "warehouse", "authenticator": "snowflake", "application": "AIRFLOW"},
-        ).render_as_string(hide_password=False),
-        "connect_args": {"private_key": private_key_bytes},
-    }
-    operator = GreatExpectationsOperator(
-        task_id="task_id",
-        data_context_config=in_memory_data_context_config,
-        data_asset_name="test_runtime_data_asset",
-        conn_id="snowflake_default",
-        query_to_validate="SELECT * FROM db;",
-        expectation_suite_name="suite",
-    )
-    operator.conn = Connection(
-        conn_id="snowflake_default",
-        conn_type="snowflake",
-        host="connection",
-        login="user",
-        password="password",
-        schema="schema",
-        port=5439,
-        extra={
-            "extra__snowflake__role": "role",
-            "extra__snowflake__warehouse": "warehouse",
-            "extra__snowflake__database": "database",
-            "extra__snowflake__region": "region-east-1",
-            "extra__snowflake__account": "account",
-            "extra__snowflake__private_key_file": "/path/to/key.p8",
-        },
-    )
-    operator.conn_type = operator.conn.conn_type
+    # create a temp key file
+    with tempfile.NamedTemporaryFile(delete=True) as temp_file:
+        private_key_bytes = b"fake_key"
+        temp_file.write(private_key_bytes)
+        temp_file.flush()
+        test_conn_conf = {
+            "url": URL.create(
+                drivername="snowflake",
+                username="user",
+                password="",
+                host="account.region-east-1",
+                database="database/schema",
+                query={
+                    "role": "role",
+                    "warehouse": "warehouse",
+                    "authenticator": "snowflake",
+                    "application": "AIRFLOW",
+                },
+            ).render_as_string(hide_password=False),
+            "connect_args": {"private_key": private_key_bytes},
+        }
+        operator = GreatExpectationsOperator(
+            task_id="task_id",
+            data_context_config=in_memory_data_context_config,
+            data_asset_name="test_runtime_data_asset",
+            conn_id="snowflake_default",
+            query_to_validate="SELECT * FROM db;",
+            expectation_suite_name="suite",
+        )
+        operator.conn = Connection(
+            conn_id="snowflake_default",
+            conn_type="snowflake",
+            host="connection",
+            login="user",
+            password="password",
+            schema="schema",
+            port=5439,
+            extra={
+                "extra__snowflake__role": "role",
+                "extra__snowflake__warehouse": "warehouse",
+                "extra__snowflake__database": "database",
+                "extra__snowflake__region": "region-east-1",
+                "extra__snowflake__account": "account",
+                "extra__snowflake__private_key_file": temp_file.name,
+            },
+        )
+        operator.conn_type = operator.conn.conn_type
 
-    mocker.patch(
-        "airflow.providers.snowflake.hooks.snowflake.SnowflakeHook.get_connection", return_value=operator.conn
-    )
-    mocker.patch("great_expectations_provider.operators.great_expectations.Path.read_bytes", return_value=b"dummy")
-    mocked_key = mock.MagicMock(default_backend())
-    mocked_key.private_bytes = mock.MagicMock(return_value=private_key_bytes)
-    mocker.patch(
-        "cryptography.hazmat.primitives.serialization.load_pem_private_key",
-        return_value=mocked_key,
-    )
+        mocker.patch(
+            "airflow.providers.snowflake.hooks.snowflake.SnowflakeHook.get_connection", return_value=operator.conn
+        )
+        mocker.patch(
+            "great_expectations_provider.operators.great_expectations.Path.read_bytes", return_value=b"fake_key"
+        )
+        mocked_key = mock.MagicMock(default_backend())
+        mocked_key.private_bytes = mock.MagicMock(return_value=private_key_bytes)
+        mocker.patch(
+            "cryptography.hazmat.primitives.serialization.load_pem_private_key",
+            return_value=mocked_key,
+        )
 
-    assert operator.make_connection_configuration() == test_conn_conf
+        assert operator.make_connection_configuration() == test_conn_conf
 
 
 def test_great_expectations_operator__make_connection_string_sqlite():
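The rewritten test also swaps the hard-coded `/path/to/key.p8` for a real temporary file, so the `private_key_file` extra points at readable bytes for the duration of the assertion. A minimal sketch of that fixture pattern, using only the standard library:

```python
import tempfile

# Write fake key bytes to a real file that lives for the duration of the
# with-block; delete=True removes it on exit.
with tempfile.NamedTemporaryFile(delete=True) as temp_file:
    temp_file.write(b"fake_key")
    temp_file.flush()  # ensure the bytes are on disk before anything re-opens the path
    key_path = temp_file.name  # value to place in extra__snowflake__private_key_file
    print(key_path)
```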