Skip to content

Commit

Permalink
Fix connection parameters of SnowflakeValueCheckOperator (#32605)
Browse files Browse the repository at this point in the history
  • Loading branch information
frodo2000 committed Jul 31, 2023
1 parent 2b0d88e commit 2ab78ec
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 32 deletions.
65 changes: 33 additions & 32 deletions airflow/providers/snowflake/operators/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,18 +212,18 @@ def __init__(
session_parameters: dict | None = None,
**kwargs,
) -> None:
if any([warehouse, database, role, schema, authenticator, session_parameters]):
hook_params = kwargs.pop("hook_params", {})
kwargs["hook_params"] = {
"warehouse": warehouse,
"database": database,
"role": role,
"schema": schema,
"authenticator": authenticator,
"session_parameters": session_parameters,
**hook_params,
}
super().__init__(sql=sql, parameters=parameters, conn_id=snowflake_conn_id, **kwargs)
self.snowflake_conn_id = snowflake_conn_id
self.sql = sql
self.autocommit = autocommit
self.do_xcom_push = do_xcom_push
self.parameters = parameters
self.warehouse = warehouse
self.database = database
self.role = role
self.schema = schema
self.authenticator = authenticator
self.session_parameters = session_parameters
self.query_ids: list[str] = []


Expand Down Expand Up @@ -277,20 +277,20 @@ def __init__(
session_parameters: dict | None = None,
**kwargs,
) -> None:
if any([warehouse, database, role, schema, authenticator, session_parameters]):
hook_params = kwargs.pop("hook_params", {})
kwargs["hook_params"] = {
"warehouse": warehouse,
"database": database,
"role": role,
"schema": schema,
"authenticator": authenticator,
"session_parameters": session_parameters,
**hook_params,
}
super().__init__(
sql=sql, pass_value=pass_value, tolerance=tolerance, conn_id=snowflake_conn_id, **kwargs
)
self.snowflake_conn_id = snowflake_conn_id
self.sql = sql
self.autocommit = autocommit
self.do_xcom_push = do_xcom_push
self.parameters = parameters
self.warehouse = warehouse
self.database = database
self.role = role
self.schema = schema
self.authenticator = authenticator
self.session_parameters = session_parameters
self.query_ids: list[str] = []


Expand Down Expand Up @@ -352,6 +352,17 @@ def __init__(
session_parameters: dict | None = None,
**kwargs,
) -> None:
if any([warehouse, database, role, schema, authenticator, session_parameters]):
hook_params = kwargs.pop("hook_params", {})
kwargs["hook_params"] = {
"warehouse": warehouse,
"database": database,
"role": role,
"schema": schema,
"authenticator": authenticator,
"session_parameters": session_parameters,
**hook_params,
}
super().__init__(
table=table,
metrics_thresholds=metrics_thresholds,
Expand All @@ -360,16 +371,6 @@ def __init__(
conn_id=snowflake_conn_id,
**kwargs,
)
self.snowflake_conn_id = snowflake_conn_id
self.autocommit = autocommit
self.do_xcom_push = do_xcom_push
self.parameters = parameters
self.warehouse = warehouse
self.database = database
self.role = role
self.schema = schema
self.authenticator = authenticator
self.session_parameters = session_parameters
self.query_ids: list[str] = []


Expand Down
73 changes: 73 additions & 0 deletions tests/providers/snowflake/operators/test_snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,36 @@ def test_snowflake_operator(self, mock_get_db_hook):
operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)


class TestSnowflakeOperatorForParams:
@mock.patch("airflow.providers.common.sql.operators.sql.BaseSQLOperator.__init__")
def test_overwrite_params(self, mock_base_op):
sql = "Select * from test_table"
SnowflakeOperator(
sql=sql,
task_id="snowflake_params_check",
snowflake_conn_id="snowflake_default",
warehouse="test_warehouse",
database="test_database",
role="test_role",
schema="test_schema",
authenticator="oath",
session_parameters={"QUERY_TAG": "test_tag"},
)
mock_base_op.assert_called_once_with(
conn_id="snowflake_default",
task_id="snowflake_params_check",
hook_params={
"warehouse": "test_warehouse",
"database": "test_database",
"role": "test_role",
"schema": "test_schema",
"authenticator": "oath",
"session_parameters": {"QUERY_TAG": "test_tag"},
},
default_args={},
)


@pytest.mark.parametrize(
"operator_class, kwargs",
[
Expand All @@ -93,6 +123,49 @@ def test_get_db_hook(
mock_get_db_hook.assert_called_once()


@pytest.mark.parametrize(
"operator_class, kwargs",
[
(SnowflakeCheckOperator, dict(sql="Select * from test_table")),
(SnowflakeValueCheckOperator, dict(sql="Select * from test_table", pass_value=95)),
(SnowflakeIntervalCheckOperator, dict(table="test-table-id", metrics_thresholds={"COUNT(*)": 1.5})),
],
)
class TestSnowflakeCheckOperatorsForParams:
@mock.patch("airflow.providers.common.sql.operators.sql.BaseSQLOperator.__init__")
def test_overwrite_params(
self,
mock_base_op,
operator_class,
kwargs,
):
operator_class(
task_id="snowflake_params_check",
snowflake_conn_id="snowflake_default",
warehouse="test_warehouse",
database="test_database",
role="test_role",
schema="test_schema",
authenticator="oath",
session_parameters={"QUERY_TAG": "test_tag"},
**kwargs,
)
mock_base_op.assert_called_once_with(
conn_id="snowflake_default",
database=None,
task_id="snowflake_params_check",
hook_params={
"warehouse": "test_warehouse",
"database": "test_database",
"role": "test_role",
"schema": "test_schema",
"authenticator": "oath",
"session_parameters": {"QUERY_TAG": "test_tag"},
},
default_args={},
)


def create_context(task, dag=None):
if dag is None:
dag = DAG(dag_id="dag")
Expand Down

0 comments on commit 2ab78ec

Please sign in to comment.