Add schema parameter and data_asset_name parsing #75

Merged · 4 commits · Dec 8, 2022
26 changes: 22 additions & 4 deletions great_expectations_provider/operators/great_expectations.py
@@ -79,6 +79,8 @@ class GreatExpectationsOperator(BaseOperator):

:param run_name: Identifies the validation run (defaults to timestamp if not specified)
:type run_name: Optional[str]
:param conn: An Airflow Connection or dict to create a Connection
:type conn: Optional[Union[Connection, Dict]]
:param conn_id: The name of a connection in Airflow
:type conn_id: Optional[str]
:param execution_engine: The execution engine to use when running Great Expectations
@@ -94,7 +96,7 @@ class GreatExpectationsOperator(BaseOperator):
:type data_context_config: Optional[DataContextConfig]
:param dataframe_to_validate: A pandas dataframe to validate
:type dataframe_to_validate: Optional[str]
:param query_to_validate: A SQL query to validate
:type query_to_validate: Optional[str]
:param checkpoint_name: A Checkpoint name to use for validation
:type checkpoint_name: Optional[str]
@@ -111,6 +113,8 @@ class GreatExpectationsOperator(BaseOperator):
:type return_json_dict: bool
:param use_open_lineage: If True (default), creates an OpenLineage action if an OpenLineage environment is found
:type use_open_lineage: bool
:param schema: If provided, overrides the default schema provided by the connection
:type schema: Optional[str]
"""

ui_color = "#AFEEEE"
@@ -129,6 +133,7 @@ class GreatExpectationsOperator(BaseOperator):
def __init__(
self,
run_name: Optional[str] = None,
conn: Optional[Union[Connection, Dict[str, Any]]] = None,
conn_id: Optional[str] = None,
execution_engine: Optional[str] = None,
expectation_suite_name: Optional[str] = None,
@@ -144,13 +149,15 @@ def __init__(
fail_task_on_validation_failure: bool = True,
return_json_dict: bool = False,
use_open_lineage: bool = True,
schema: Optional[str] = None,
*args,
**kwargs,
) -> None:
super().__init__(*args, **kwargs)

self.data_asset_name: Optional[str] = data_asset_name
self.run_name: Optional[str] = run_name
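# A dict value for conn is treated as Connection kwargs and normalized to an airflow Connection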
self.conn: Optional[Connection] = Connection(**conn) if isinstance(conn, dict) else conn
self.conn_id: Optional[str] = conn_id
self.execution_engine: Optional[str] = execution_engine
self.expectation_suite_name: Optional[str] = expectation_suite_name
@@ -170,6 +177,7 @@ def __init__(
self.is_dataframe = True if self.dataframe_to_validate is not None else False
self.datasource: Optional[Datasource] = None
self.batch_request: Optional[BatchRequestBase] = None
self.schema: Optional[str] = schema

if self.is_dataframe and self.query_to_validate:
raise ValueError(
@@ -213,11 +221,21 @@ def __init__(
if isinstance(self.checkpoint_config, CheckpointConfig):
self.checkpoint_config = deep_filter_properties_iterable(properties=self.checkpoint_config.to_dict())

# If a schema is passed as part of the data_asset_name, use that schema
if self.data_asset_name and "." in self.data_asset_name:
# Assume data_asset_name is in the form "SCHEMA.TABLE"
# Schema parameter always takes priority
asset_list = self.data_asset_name.split(".")
self.schema = self.schema or asset_list[0]
# Update data_asset_name to be only the table
self.data_asset_name = asset_list[1]
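# e.g. data_asset_name="analytics.orders" yields schema="analytics" (unless schema was passed explicitly) and data_asset_name="orders"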

def make_connection_string(self) -> str:
"""Builds connection strings based off existing Airflow connections. Only supports necessary extras."""
uri_string = ""
if not self.conn:
raise ValueError(f"Connections does not exist in Airflow for conn_id: {self.conn_id}")
schema = self.schema or self.conn.schema
conn_type = self.conn.conn_type
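# Dispatch on the Airflow conn_type to assemble the matching SQLAlchemy-style URI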
if conn_type in ("redshift", "postgres", "mysql", "mssql"):
odbc_connector = ""
@@ -227,11 +245,11 @@ def make_connection_string(self) -> str:
odbc_connector = "mysql"
else:
odbc_connector = "mssql+pyodbc"
uri_string = f"{odbc_connector}://{self.conn.login}:{self.conn.password}@{self.conn.host}:{self.conn.port}/{self.conn.schema}" # noqa
uri_string = f"{odbc_connector}://{self.conn.login}:{self.conn.password}@{self.conn.host}:{self.conn.port}/{schema}" # noqa
elif conn_type == "snowflake":
uri_string = f"snowflake://{self.conn.login}:{self.conn.password}@{self.conn.extra_dejson['extra__snowflake__account']}.{self.conn.extra_dejson['extra__snowflake__region']}/{self.conn.extra_dejson['extra__snowflake__database']}/{self.conn.schema}?warehouse={self.conn.extra_dejson['extra__snowflake__warehouse']}&role={self.conn.extra_dejson['extra__snowflake__role']}" # noqa
uri_string = f"snowflake://{self.conn.login}:{self.conn.password}@{self.conn.extra_dejson['extra__snowflake__account']}.{self.conn.extra_dejson['extra__snowflake__region']}/{self.conn.extra_dejson['extra__snowflake__database']}/{schema}?warehouse={self.conn.extra_dejson['extra__snowflake__warehouse']}&role={self.conn.extra_dejson['extra__snowflake__role']}" # noqa
elif conn_type == "gcpbigquery":
uri_string = f"{self.conn.host}{self.conn.schema}"
uri_string = f"{self.conn.host}{schema}"
elif conn_type == "sqlite":
uri_string = f"sqlite:///{self.conn.host}"
# TODO: Add Athena and Trino support if possible
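
A minimal usage sketch of how the new parameters interact (the DAG task names, connection ids, `data_context_root_dir` path, and credentials below are hypothetical, not from this PR): an explicit `schema` overrides both the connection's default schema and any `SCHEMA.TABLE` prefix parsed from `data_asset_name`, and `conn` may be supplied as a plain dict of Connection kwargs:

```python
from great_expectations_provider.operators.great_expectations import GreatExpectationsOperator

# Explicit schema wins: the operator validates sales.orders, not analytics.orders,
# and data_asset_name is reduced to the bare table name "orders".
validate_orders = GreatExpectationsOperator(
    task_id="validate_orders",
    conn_id="warehouse_default",          # hypothetical Airflow connection
    schema="sales",                       # overrides the parsed "analytics" prefix
    data_asset_name="analytics.orders",   # split on "." into (schema, table)
    expectation_suite_name="orders_suite",
    data_context_root_dir="/usr/local/airflow/great_expectations",
)

# conn may also be a plain dict of Connection kwargs; __init__ converts it
# to an airflow.models.Connection before make_connection_string() runs.
validate_orders_dict_conn = GreatExpectationsOperator(
    task_id="validate_orders_dict_conn",
    conn={
        "conn_id": "warehouse_default",
        "conn_type": "postgres",
        "host": "db.example.com",
        "login": "user",
        "password": "hunter2",
        "schema": "public",
        "port": 5432,
    },
    data_asset_name="orders",
    expectation_suite_name="orders_suite",
    data_context_root_dir="/usr/local/airflow/great_expectations",
)
```

With the dict-based connection, the postgres branch of `make_connection_string()` builds the URI from these fields, and `schema` falls back to the connection's `public` schema since no override or schema-qualified asset name is given.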
63 changes: 63 additions & 0 deletions tests/operators/test_great_expectations.py
@@ -862,6 +862,69 @@ def test_great_expectations_operator__make_connection_string_sqlite():
assert operator.make_connection_string() == test_conn_str


def test_great_expectations_operator__make_connection_string_schema_parameter():
test_conn_str = (
"snowflake://user:password@account.region-east-1/database/test_schema_parameter?warehouse=warehouse&role=role"
)
operator = GreatExpectationsOperator(
task_id="task_id",
data_context_config=in_memory_data_context_config,
data_asset_name="test_schema.test_table",
conn_id="snowflake_default",
expectation_suite_name="suite",
schema="test_schema_parameter",
)
operator.conn = Connection(
conn_id="snowflake_default",
conn_type="snowflake",
host="connection",
login="user",
password="password",
schema="schema",
port=5439,
extra={
"extra__snowflake__role": "role",
"extra__snowflake__warehouse": "warehouse",
"extra__snowflake__database": "database",
"extra__snowflake__region": "region-east-1",
"extra__snowflake__account": "account",
},
)
operator.conn_type = operator.conn.conn_type
assert operator.make_connection_string() == test_conn_str


def test_great_expectations_operator__make_connection_string_data_asset_name_schema_parse():
test_conn_str = (
"snowflake://user:password@account.region-east-1/database/test_schema?warehouse=warehouse&role=role"
)
operator = GreatExpectationsOperator(
task_id="task_id",
data_context_config=in_memory_data_context_config,
data_asset_name="test_schema.test_table",
conn_id="snowflake_default",
expectation_suite_name="suite",
)
operator.conn = Connection(
conn_id="snowflake_default",
conn_type="snowflake",
host="connection",
login="user",
password="password",
port=5439,
extra={
"extra__snowflake__role": "role",
"extra__snowflake__warehouse": "warehouse",
"extra__snowflake__database": "database",
"extra__snowflake__region": "region-east-1",
"extra__snowflake__account": "account",
},
)
operator.conn_type = operator.conn.conn_type
assert operator.make_connection_string() == test_conn_str
assert operator.data_asset_name == "test_table"


def test_great_expectations_operator__make_connection_string_raise_error():
operator = GreatExpectationsOperator(
task_id="task_id",