Skip to content

Commit

Permalink
Add custom handler param in SnowflakeOperator (#25983)
Browse files Browse the repository at this point in the history
  • Loading branch information
pankajastro committed Aug 27, 2022
1 parent 0eec510 commit 9e12d48
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions airflow/providers/snowflake/operators/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any, Iterable, List, Mapping, Optional, Sequence, SupportsAbs, Union
from typing import Any, Callable, Iterable, List, Mapping, Optional, Sequence, SupportsAbs, Union

from airflow.models import BaseOperator
from airflow.providers.common.sql.hooks.sql import fetch_all_handler
Expand Down Expand Up @@ -78,6 +78,8 @@ class SnowflakeOperator(BaseOperator):
through native Okta.
:param session_parameters: You can set session-level parameters at
the time you connect to Snowflake
:param handler: A Python callable that will act on cursor result.
By default, it will use ``fetchall``
"""

template_fields: Sequence[str] = ('sql',)
Expand All @@ -99,6 +101,7 @@ def __init__(
schema: Optional[str] = None,
authenticator: Optional[str] = None,
session_parameters: Optional[dict] = None,
handler: Optional[Callable] = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand All @@ -114,6 +117,7 @@ def __init__(
self.authenticator = authenticator
self.session_parameters = session_parameters
self.query_ids: List[str] = []
self.handler = handler

def get_db_hook(self) -> SnowflakeHook:
return get_db_hook(self)
Expand All @@ -122,7 +126,8 @@ def execute(self, context: Any):
"""Run query on snowflake"""
self.log.info('Executing: %s', self.sql)
hook = self.get_db_hook()
execution_info = hook.run(self.sql, self.autocommit, self.parameters, fetch_all_handler)
handler = self.handler or fetch_all_handler
execution_info = hook.run(self.sql, self.autocommit, self.parameters, handler)
self.query_ids = hook.query_ids

if self.do_xcom_push:
Expand Down

0 comments on commit 9e12d48

Please sign in to comment.