Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix do_xcom_push=False bug in SnowflakeOperator #29599

Merged
merged 3 commits into from
Feb 22, 2023

Conversation

fritz-astronomer
Copy link
Contributor

@fritz-astronomer fritz-astronomer commented Feb 17, 2023

closes: #29593

Adds a guard check to short-circuit out of _process_output in the event of no output (such as do_xcom_push=False)

@boring-cyborg boring-cyborg bot added area:providers provider:snowflake Issues related to Snowflake provider labels Feb 17, 2023
@fritz-astronomer
Copy link
Contributor Author

Test DAG to fix (see linked issue for Test DAG Demonstrating error):

from __future__ import annotations

import os
from datetime import datetime
from typing import Optional, Any, Sequence

from airflow import DAG
from airflow.providers.snowflake.operators.snowflake import SnowflakeOperator

os.environ["AIRFLOW_CONN_SNOWFLAKE"] = "snowflake://.............."


class SnowflakePatchOperator(SnowflakeOperator):
    def _process_output(
        self,
        results: Optional[list[Any]],
        descriptions: list[Sequence[Sequence] | None]
    ) -> list[Any]:
        # Handle do_xcom_push=False
        if results is None or results == [None]:
            return [None]
        validated_descriptions: list[Sequence[Sequence]] = []
        for idx, description in enumerate(descriptions):
            if not description:
                raise RuntimeError(
                    f"The query did not return descriptions of the cursor for query number {idx}. "
                    "Cannot return values in a form of dictionary for that query."
                )
            validated_descriptions.append(description)
        returned_results = []
        for result_id, result_list in enumerate(results):
            current_processed_result = []
            for row in result_list:
                dict_result: dict[Any, Any] = {}
                for idx, description in enumerate(validated_descriptions[result_id]):
                    dict_result[description[0]] = row[idx]
                current_processed_result.append(dict_result)
            returned_results.append(current_processed_result)
        return returned_results



with DAG('snowflake_test', schedule=None, start_date=datetime(2023, 1, 1)):
    SnowflakePatchOperator(
        task_id='snowflake_test',
        snowflake_conn_id="snowflake",
        sql="select 1;",
        do_xcom_push=False
    )

Copy link
Member

@hussein-awala hussein-awala left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this method is defined and called in the super class SQLExecuteQueryOperator, I prefer to add the check in this bloc to avoid adding it in all the sub classes, and please add a unit test to check if everything works as expected.

@fritz-astronomer
Copy link
Contributor Author

fritz-astronomer commented Feb 17, 2023

@hussein-awala - the bug is specifically in the SnowflakeProvider - the base class just does a no-op with _process_output and isn't affected, I was hesitant to add it there in case the fix interferes with other providers some how by skipping a _process_output that does do something with an empty result set 🤔

--- a/airflow/providers/common/sql/operators/sql.py
+++ b/airflow/providers/common/sql/operators/sql.py
@@ -265,6 +265,9 @@ class SQLExecuteQueryOperator(BaseSQLOperator):
             return_last=self.return_last,
             **extra_kwargs,
         )
+        # Handle do_xcom_push=False
+        if output is None or output == [None]:
+            return []
         if return_single_query_results(self.sql, self.return_last, self.split_statements):
             # For simplicity, we pass always list as input to _process_output, regardless if
             # single query results are going to be returned, and we return the first element

--- a/airflow/providers/snowflake/operators/snowflake.py
+++ b/airflow/providers/snowflake/operators/snowflake.py
@@ -108,9 +108,6 @@ class SnowflakeOperator(SQLExecuteQueryOperator):
         results: Optional[list[Any]],
         descriptions: list[Sequence[Sequence] | None]
     ) -> list[Any]:
-        # Handle do_xcom_push=False
-        if results is None or results == [None]:
-            return []
         validated_descriptions: list[Sequence[Sequence]] = []
         for idx, description in enumerate(descriptions):
             if not description:

@Taragolis
Copy link
Contributor

Unfortunetly Statics Check and also required tests for avoid regression.

@hussein-awala
Copy link
Member

@hussein-awala - the bug is specifically in the SnowflakeProvider - the base class just does a no-op with _process_output and isn't affected, I was hesitant to add it there in case the fix interferes with other providers some how by skipping a _process_output that does do something with an empty result set thinking

--- a/airflow/providers/common/sql/operators/sql.py
+++ b/airflow/providers/common/sql/operators/sql.py
@@ -265,6 +265,9 @@ class SQLExecuteQueryOperator(BaseSQLOperator):
             return_last=self.return_last,
             **extra_kwargs,
         )
+        # Handle do_xcom_push=False
+        if output is None or output == [None]:
+            return []
         if return_single_query_results(self.sql, self.return_last, self.split_statements):
             # For simplicity, we pass always list as input to _process_output, regardless if
             # single query results are going to be returned, and we return the first element

--- a/airflow/providers/snowflake/operators/snowflake.py
+++ b/airflow/providers/snowflake/operators/snowflake.py
@@ -108,9 +108,6 @@ class SnowflakeOperator(SQLExecuteQueryOperator):
         results: Optional[list[Any]],
         descriptions: list[Sequence[Sequence] | None]
     ) -> list[Any]:
-        # Handle do_xcom_push=False
-        if results is None or results == [None]:
-            return []
         validated_descriptions: list[Sequence[Sequence]] = []
         for idx, description in enumerate(descriptions):
             if not description:

I think adding a condition on the SQLExecuteQueryOperator without changing the snowflake operator is enough to solve the problem, since we don't need to process the result because we don't want to push it as a xcom:

--- a/airflow/providers/common/sql/operators/sql.py
+++ b/airflow/providers/common/sql/operators/sql.py
@@ -265,6 +265,8 @@ class SQLExecuteQueryOperator(BaseSQLOperator):
             return_last=self.return_last,
             **extra_kwargs,
         )
+        if not self.do_xcom_push:
+            return None
         if return_single_query_results(self.sql, self.return_last, self.split_statements):
             # For simplicity, we pass always list as input to _process_output, regardless if
             # single query results are going to be returned, and we return the first element

What do you think?

@fritz-astronomer
Copy link
Contributor Author

Oh! Brilliant! Just short circuit and skip all the rest. That makes sense. Not sure why I was worried that method might be needed for side-effects other than xcom

@fritz-astronomer
Copy link
Contributor Author

Did a quick rebase just to clean up commits.
Tests are added and passing in breeze ✅

@bolkedebruin bolkedebruin merged commit 19f1e7c into apache:main Feb 22, 2023
@fritz-astronomer fritz-astronomer deleted the snowflake_xcom_fix branch February 22, 2023 16:58
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
area:providers provider:snowflake Issues related to Snowflake provider
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Cannot disable XCom push in SnowflakeOperator
5 participants