From 2ab78ec441a748ae4d99e429fe336b80a601d7b1 Mon Sep 17 00:00:00 2001 From: Marcin Molak Date: Mon, 31 Jul 2023 21:21:00 +0200 Subject: [PATCH] Fix connection parameters of `SnowflakeValueCheckOperator` (#32605) --- .../snowflake/operators/snowflake.py | 65 +++++++++-------- .../snowflake/operators/test_snowflake.py | 73 +++++++++++++++++++ 2 files changed, 106 insertions(+), 32 deletions(-) diff --git a/airflow/providers/snowflake/operators/snowflake.py b/airflow/providers/snowflake/operators/snowflake.py index 1de82218ff34c..8f29eefd5b936 100644 --- a/airflow/providers/snowflake/operators/snowflake.py +++ b/airflow/providers/snowflake/operators/snowflake.py @@ -212,18 +212,18 @@ def __init__( session_parameters: dict | None = None, **kwargs, ) -> None: + if any([warehouse, database, role, schema, authenticator, session_parameters]): + hook_params = kwargs.pop("hook_params", {}) + kwargs["hook_params"] = { + "warehouse": warehouse, + "database": database, + "role": role, + "schema": schema, + "authenticator": authenticator, + "session_parameters": session_parameters, + **hook_params, + } super().__init__(sql=sql, parameters=parameters, conn_id=snowflake_conn_id, **kwargs) - self.snowflake_conn_id = snowflake_conn_id - self.sql = sql - self.autocommit = autocommit - self.do_xcom_push = do_xcom_push - self.parameters = parameters - self.warehouse = warehouse - self.database = database - self.role = role - self.schema = schema - self.authenticator = authenticator - self.session_parameters = session_parameters self.query_ids: list[str] = [] @@ -277,20 +277,20 @@ def __init__( session_parameters: dict | None = None, **kwargs, ) -> None: + if any([warehouse, database, role, schema, authenticator, session_parameters]): + hook_params = kwargs.pop("hook_params", {}) + kwargs["hook_params"] = { + "warehouse": warehouse, + "database": database, + "role": role, + "schema": schema, + "authenticator": authenticator, + "session_parameters": session_parameters, + **hook_params, + } super().__init__( sql=sql, pass_value=pass_value, tolerance=tolerance, conn_id=snowflake_conn_id, **kwargs ) - self.snowflake_conn_id = snowflake_conn_id - self.sql = sql - self.autocommit = autocommit - self.do_xcom_push = do_xcom_push - self.parameters = parameters - self.warehouse = warehouse - self.database = database - self.role = role - self.schema = schema - self.authenticator = authenticator - self.session_parameters = session_parameters self.query_ids: list[str] = [] @@ -352,6 +352,17 @@ def __init__( session_parameters: dict | None = None, **kwargs, ) -> None: + if any([warehouse, database, role, schema, authenticator, session_parameters]): + hook_params = kwargs.pop("hook_params", {}) + kwargs["hook_params"] = { + "warehouse": warehouse, + "database": database, + "role": role, + "schema": schema, + "authenticator": authenticator, + "session_parameters": session_parameters, + **hook_params, + } super().__init__( table=table, metrics_thresholds=metrics_thresholds, @@ -360,16 +371,6 @@ def __init__( conn_id=snowflake_conn_id, **kwargs, ) - self.snowflake_conn_id = snowflake_conn_id - self.autocommit = autocommit - self.do_xcom_push = do_xcom_push - self.parameters = parameters - self.warehouse = warehouse - self.database = database - self.role = role - self.schema = schema - self.authenticator = authenticator - self.session_parameters = session_parameters self.query_ids: list[str] = [] diff --git a/tests/providers/snowflake/operators/test_snowflake.py b/tests/providers/snowflake/operators/test_snowflake.py index 41cbfe6717464..ea9d8333f3271 100644 --- a/tests/providers/snowflake/operators/test_snowflake.py +++ b/tests/providers/snowflake/operators/test_snowflake.py @@ -72,6 +72,36 @@ def test_snowflake_operator(self, mock_get_db_hook): operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) +class TestSnowflakeOperatorForParams: + @mock.patch("airflow.providers.common.sql.operators.sql.BaseSQLOperator.__init__") + def test_overwrite_params(self, mock_base_op): + sql = "Select * from test_table" + SnowflakeOperator( + sql=sql, + task_id="snowflake_params_check", + snowflake_conn_id="snowflake_default", + warehouse="test_warehouse", + database="test_database", + role="test_role", + schema="test_schema", + authenticator="oath", + session_parameters={"QUERY_TAG": "test_tag"}, + ) + mock_base_op.assert_called_once_with( + conn_id="snowflake_default", + task_id="snowflake_params_check", + hook_params={ + "warehouse": "test_warehouse", + "database": "test_database", + "role": "test_role", + "schema": "test_schema", + "authenticator": "oath", + "session_parameters": {"QUERY_TAG": "test_tag"}, + }, + default_args={}, + ) + + @pytest.mark.parametrize( "operator_class, kwargs", [ @@ -93,6 +123,49 @@ def test_get_db_hook( mock_get_db_hook.assert_called_once() +@pytest.mark.parametrize( + "operator_class, kwargs", + [ + (SnowflakeCheckOperator, dict(sql="Select * from test_table")), + (SnowflakeValueCheckOperator, dict(sql="Select * from test_table", pass_value=95)), + (SnowflakeIntervalCheckOperator, dict(table="test-table-id", metrics_thresholds={"COUNT(*)": 1.5})), + ], +) +class TestSnowflakeCheckOperatorsForParams: + @mock.patch("airflow.providers.common.sql.operators.sql.BaseSQLOperator.__init__") + def test_overwrite_params( + self, + mock_base_op, + operator_class, + kwargs, + ): + operator_class( + task_id="snowflake_params_check", + snowflake_conn_id="snowflake_default", + warehouse="test_warehouse", + database="test_database", + role="test_role", + schema="test_schema", + authenticator="oath", + session_parameters={"QUERY_TAG": "test_tag"}, + **kwargs, + ) + mock_base_op.assert_called_once_with( + conn_id="snowflake_default", + database=None, + task_id="snowflake_params_check", + hook_params={ + "warehouse": "test_warehouse", + "database": "test_database", + "role": "test_role", + "schema": "test_schema", + "authenticator": "oath", + "session_parameters": {"QUERY_TAG": "test_tag"}, + }, + default_args={}, + ) + + def create_context(task, dag=None): if dag is None: dag = DAG(dag_id="dag")