Skip to content

Commit

Permalink
openlineage, snowflake: do not run external queries for Snowflake
Browse files Browse the repository at this point in the history
Signed-off-by: Maciej Obuchowski <obuchowski.maciej@gmail.com>
  • Loading branch information
mobuchowski committed Apr 18, 2024
1 parent 1769ed0 commit 3b26f58
Show file tree
Hide file tree
Showing 9 changed files with 128 additions and 123 deletions.
9 changes: 9 additions & 0 deletions airflow/providers/common/sql/operators/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,14 @@ def get_openlineage_facets_on_start(self) -> OperatorLineage | None:

hook = self.get_db_hook()

try:
from airflow.providers.openlineage.utils.utils import should_use_external_connection

use_external_connection = should_use_external_connection(hook)
except ImportError:
# OpenLineage provider release < 1.8.0 - we always use connection
use_external_connection = True

connection = hook.get_connection(getattr(hook, hook.conn_name_attr))
try:
database_info = hook.get_openlineage_database_info(connection)
Expand All @@ -334,6 +342,7 @@ def get_openlineage_facets_on_start(self) -> OperatorLineage | None:
database_info=database_info,
database=self.database,
sqlalchemy_engine=hook.get_sqlalchemy_engine(),
use_connection=use_external_connection,
)

return operator_lineage
Expand Down
61 changes: 50 additions & 11 deletions airflow/providers/openlineage/sqlparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
ExtractionErrorRunFacet,
SqlJobFacet,
)
from openlineage.client.run import Dataset
from openlineage.common.sql import DbTableMeta, SqlMeta, parse

from airflow.providers.openlineage.extractors.base import OperatorLineage
Expand All @@ -40,7 +41,6 @@
from airflow.typing_compat import TypedDict

if TYPE_CHECKING:
from openlineage.client.run import Dataset
from sqlalchemy.engine import Engine

from airflow.hooks.base import BaseHook
Expand Down Expand Up @@ -104,6 +104,18 @@ class DatabaseInfo:
normalize_name_method: Callable[[str], str] = default_normalize_name_method


def from_table_meta(
    table_meta: DbTableMeta, database: str | None, namespace: str, is_uppercase: bool
) -> Dataset:
    """Build an OpenLineage ``Dataset`` from parser-extracted table metadata.

    The dataset name is the table's fully qualified name when the parser
    already resolved a database; otherwise it is assembled from the supplied
    *database* (if any), the table's schema, and the table name. When
    *is_uppercase* is set, the resulting name is upper-cased (e.g. for
    engines with case-insensitive, upper-cased identifiers).
    """
    if table_meta.database:
        qualified = table_meta.qualified_name
    else:
        prefix = f"{database}." if database else ""
        qualified = f"{prefix}{table_meta.schema}.{table_meta.name}"
    return Dataset(namespace=namespace, name=qualified.upper() if is_uppercase else qualified)


class SQLParser:
"""Interface for openlineage-sql.
Expand All @@ -117,7 +129,7 @@ def __init__(self, dialect: str | None = None, default_schema: str | None = None

def parse(self, sql: list[str] | str) -> SqlMeta | None:
"""Parse a single or a list of SQL statements."""
return parse(sql=sql, dialect=self.dialect)
return parse(sql=sql, dialect=self.dialect, default_schema=self.default_schema)

def parse_table_schemas(
self,
Expand Down Expand Up @@ -156,6 +168,23 @@ def parse_table_schemas(
else None,
)

def get_metadata_from_parser(
    self,
    inputs: list[DbTableMeta],
    outputs: list[DbTableMeta],
    database_info: DatabaseInfo,
    namespace: str = DEFAULT_NAMESPACE,
    database: str | None = None,
) -> tuple[list[Dataset], ...]:
    """Build input/output datasets purely from parser results.

    Unlike ``parse_table_schemas`` this issues no queries against the
    database — dataset names are derived from the parsed table metadata
    only. Returns a ``(inputs, outputs)`` pair of dataset lists.
    """
    # Fall back to the connection's database when no explicit one was given.
    effective_db = database or database_info.database
    uppercase = database_info.is_uppercase_names
    in_datasets = [from_table_meta(meta, effective_db, namespace, uppercase) for meta in inputs]
    out_datasets = [from_table_meta(meta, effective_db, namespace, uppercase) for meta in outputs]
    return in_datasets, out_datasets

def attach_column_lineage(
self, datasets: list[Dataset], database: str | None, parse_result: SqlMeta
) -> None:
Expand Down Expand Up @@ -204,6 +233,7 @@ def generate_openlineage_metadata_from_sql(
database_info: DatabaseInfo,
database: str | None = None,
sqlalchemy_engine: Engine | None = None,
use_connection: bool = True,
) -> OperatorLineage:
"""Parse SQL statement(s) and generate OpenLineage metadata.
Expand Down Expand Up @@ -242,15 +272,24 @@ def generate_openlineage_metadata_from_sql(
)

namespace = self.create_namespace(database_info=database_info)
inputs, outputs = self.parse_table_schemas(
hook=hook,
inputs=parse_result.in_tables,
outputs=parse_result.out_tables,
namespace=namespace,
database=database,
database_info=database_info,
sqlalchemy_engine=sqlalchemy_engine,
)
if use_connection:
inputs, outputs = self.parse_table_schemas(
hook=hook,
inputs=parse_result.in_tables,
outputs=parse_result.out_tables,
namespace=namespace,
database=database,
database_info=database_info,
sqlalchemy_engine=sqlalchemy_engine,
)
else:
inputs, outputs = self.get_metadata_from_parser(
inputs=parse_result.in_tables,
outputs=parse_result.out_tables,
namespace=namespace,
database=database,
database_info=database_info,
)

self.attach_column_lineage(outputs, database or database_info.database, parse_result)

Expand Down
5 changes: 5 additions & 0 deletions airflow/providers/openlineage/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,3 +384,8 @@ def normalize_sql(sql: str | Iterable[str]):
sql = [stmt for stmt in sql.split(";") if stmt != ""]
sql = [obj for stmt in sql for obj in stmt.split(";") if obj != ""]
return ";\n".join(sql)


def should_use_external_connection(hook) -> bool:
    """Return whether OpenLineage extraction may query the database via *hook*.

    Snowflake hooks are excluded so that no extra metadata queries are run
    against Snowflake; for those, dataset information comes from the SQL
    parser alone.
    """
    # TODO: Add checking overrides
    # Set membership instead of a list: clearer intent and O(1) lookup.
    return hook.__class__.__name__ not in {"SnowflakeHook", "SnowflakeSqlApiHook"}
24 changes: 6 additions & 18 deletions airflow/providers/snowflake/hooks/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import os
from contextlib import closing, contextmanager
from functools import cached_property
from io import StringIO
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Iterable, Mapping, TypeVar, overload
Expand Down Expand Up @@ -177,6 +178,7 @@ def _get_field(self, extra_dict, field_name):
return extra_dict[field_name] or None
return extra_dict.get(backcompat_key) or None

@cached_property
def _get_conn_params(self) -> dict[str, str | None]:
"""Fetch connection params as a dict.
Expand Down Expand Up @@ -269,7 +271,7 @@ def _get_conn_params(self) -> dict[str, str | None]:

def get_uri(self) -> str:
"""Override DbApiHook get_uri method for get_sqlalchemy_engine()."""
conn_params = self._get_conn_params()
conn_params = self._get_conn_params
return self._conn_params_to_sqlalchemy_uri(conn_params)

def _conn_params_to_sqlalchemy_uri(self, conn_params: dict) -> str:
Expand All @@ -283,7 +285,7 @@ def _conn_params_to_sqlalchemy_uri(self, conn_params: dict) -> str:

def get_conn(self) -> SnowflakeConnection:
"""Return a snowflake.connection object."""
conn_config = self._get_conn_params()
conn_config = self._get_conn_params
conn = connector.connect(**conn_config)
return conn

Expand All @@ -294,7 +296,7 @@ def get_sqlalchemy_engine(self, engine_kwargs=None):
:return: the created engine.
"""
engine_kwargs = engine_kwargs or {}
conn_params = self._get_conn_params()
conn_params = self._get_conn_params
if "insecure_mode" in conn_params:
engine_kwargs.setdefault("connect_args", {})
engine_kwargs["connect_args"]["insecure_mode"] = True
Expand Down Expand Up @@ -458,21 +460,7 @@ def get_openlineage_database_dialect(self, _) -> str:
return "snowflake"

def get_openlineage_default_schema(self) -> str | None:
"""
Attempt to get current schema.
Usually ``SELECT CURRENT_SCHEMA();`` should work.
However, apparently you may set ``database`` without ``schema``
and get results from ``SELECT CURRENT_SCHEMAS();`` but not
from ``SELECT CURRENT_SCHEMA();``.
It still may return nothing if no database is set in connection.
"""
schema = self._get_conn_params()["schema"]
if not schema:
current_schemas = self.get_first("SELECT PARSE_JSON(CURRENT_SCHEMAS())[0]::string;")[0]
if current_schemas:
_, schema = current_schemas.split(".")
return schema
return self._get_conn_params["schema"]

def _get_openlineage_authority(self, _) -> str:
from openlineage.common.provider.snowflake import fix_snowflake_sqlalchemy_uri
Expand Down
8 changes: 4 additions & 4 deletions airflow/providers/snowflake/hooks/snowflake_sql_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def __init__(
@property
def account_identifier(self) -> str:
"""Returns snowflake account identifier."""
conn_config = self._get_conn_params()
conn_config = self._get_conn_params
account_identifier = f"https://{conn_config['account']}"

if conn_config["region"]:
Expand Down Expand Up @@ -147,7 +147,7 @@ def execute_query(
When executing the statement, Snowflake replaces placeholders (? and :name) in
the statement with these specified values.
"""
conn_config = self._get_conn_params()
conn_config = self._get_conn_params

req_id = uuid.uuid4()
url = f"{self.account_identifier}.snowflakecomputing.com/api/v2/statements"
Expand Down Expand Up @@ -186,7 +186,7 @@ def execute_query(

def get_headers(self) -> dict[str, Any]:
"""Form auth headers based on either OAuth token or JWT token from private key."""
conn_config = self._get_conn_params()
conn_config = self._get_conn_params

# Use OAuth if refresh_token and client_id and client_secret are provided
if all(
Expand Down Expand Up @@ -225,7 +225,7 @@ def get_headers(self) -> dict[str, Any]:

def get_oauth_token(self) -> str:
"""Generate temporary OAuth access token using refresh token in connection details."""
conn_config = self._get_conn_params()
conn_config = self._get_conn_params
url = f"{self.account_identifier}.snowflakecomputing.com/oauth/token-request"
data = {
"grant_type": "refresh_token",
Expand Down
6 changes: 4 additions & 2 deletions tests/providers/amazon/aws/operators/test_redshift_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,8 @@ def get_db_hook(self):
"SVV_REDSHIFT_COLUMNS.data_type, "
"SVV_REDSHIFT_COLUMNS.database_name \n"
"FROM SVV_REDSHIFT_COLUMNS \n"
"WHERE SVV_REDSHIFT_COLUMNS.table_name IN ('little_table') "
"WHERE SVV_REDSHIFT_COLUMNS.schema_name = 'database.public' "
"AND SVV_REDSHIFT_COLUMNS.table_name IN ('little_table') "
"OR SVV_REDSHIFT_COLUMNS.database_name = 'another_db' "
"AND SVV_REDSHIFT_COLUMNS.schema_name = 'another_schema' AND "
"SVV_REDSHIFT_COLUMNS.table_name IN ('popular_orders_day_of_week')"
Expand All @@ -171,7 +172,8 @@ def get_db_hook(self):
"SVV_REDSHIFT_COLUMNS.data_type, "
"SVV_REDSHIFT_COLUMNS.database_name \n"
"FROM SVV_REDSHIFT_COLUMNS \n"
"WHERE SVV_REDSHIFT_COLUMNS.table_name IN ('Test_table')"
"WHERE SVV_REDSHIFT_COLUMNS.schema_name = 'database.public' "
"AND SVV_REDSHIFT_COLUMNS.table_name IN ('Test_table')"
),
]

Expand Down
37 changes: 14 additions & 23 deletions tests/providers/snowflake/hooks/test_snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ def test_hook_should_support_prepare_basic_conn_params_and_uri(
):
with mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()):
assert SnowflakeHook(snowflake_conn_id="test_conn").get_uri() == expected_uri
assert SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params() == expected_conn_params
assert SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params == expected_conn_params

def test_get_conn_params_should_support_private_auth_in_connection(
self, encrypted_temporary_private_key: Path
Expand All @@ -288,7 +288,7 @@ def test_get_conn_params_should_support_private_auth_in_connection(
},
}
with mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()):
assert "private_key" in SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params()
assert "private_key" in SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params

@pytest.mark.parametrize("include_params", [True, False])
def test_hook_param_beats_extra(self, include_params):
Expand All @@ -311,7 +311,7 @@ def test_hook_param_beats_extra(self, include_params):
assert hook_params != extras
assert SnowflakeHook(
snowflake_conn_id="test_conn", **(hook_params if include_params else {})
)._get_conn_params() == {
)._get_conn_params == {
"user": None,
"password": "",
"application": "AIRFLOW",
Expand Down Expand Up @@ -340,7 +340,7 @@ def test_extra_short_beats_long(self, include_unprefixed):
).get_uri(),
):
assert list(extras.values()) != list(extras_prefixed.values())
assert SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params() == {
assert SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params == {
"user": None,
"password": "",
"application": "AIRFLOW",
Expand All @@ -366,7 +366,7 @@ def test_get_conn_params_should_support_private_auth_with_encrypted_key(
},
}
with mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()):
assert "private_key" in SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params()
assert "private_key" in SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params

def test_get_conn_params_should_support_private_auth_with_unencrypted_key(
self, non_encrypted_temporary_private_key
Expand All @@ -384,15 +384,15 @@ def test_get_conn_params_should_support_private_auth_with_unencrypted_key(
},
}
with mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()):
assert "private_key" in SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params()
assert "private_key" in SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params
connection_kwargs["password"] = ""
with mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()):
assert "private_key" in SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params()
assert "private_key" in SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params
connection_kwargs["password"] = _PASSWORD
with mock.patch.dict(
"os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()
), pytest.raises(TypeError, match="Password was given but private key is not encrypted."):
SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params()
SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params

def test_get_conn_params_should_fail_on_invalid_key(self):
connection_kwargs = {
Expand All @@ -419,8 +419,7 @@ def test_should_add_partner_info(self):
AIRFLOW_SNOWFLAKE_PARTNER="PARTNER_NAME",
):
assert (
SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params()["application"]
== "PARTNER_NAME"
SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params["application"] == "PARTNER_NAME"
)

def test_get_conn_should_call_connect(self):
Expand All @@ -429,7 +428,7 @@ def test_get_conn_should_call_connect(self):
), mock.patch("airflow.providers.snowflake.hooks.snowflake.connector") as mock_connector:
hook = SnowflakeHook(snowflake_conn_id="test_conn")
conn = hook.get_conn()
mock_connector.connect.assert_called_once_with(**hook._get_conn_params())
mock_connector.connect.assert_called_once_with(**hook._get_conn_params)
assert mock_connector.connect.return_value == conn

def test_get_sqlalchemy_engine_should_support_pass_auth(self):
Expand Down Expand Up @@ -516,7 +515,7 @@ def test_hook_parameters_should_take_precedence(self):
"session_parameters": {"AA": "AAA"},
"user": "user",
"warehouse": "TEST_WAREHOUSE",
} == hook._get_conn_params()
} == hook._get_conn_params
assert (
"snowflake://user:pw@TEST_ACCOUNT.TEST_REGION/TEST_DATABASE/TEST_SCHEMA"
"?application=AIRFLOW&authenticator=TEST_AUTH&role=TEST_ROLE&warehouse=TEST_WAREHOUSE"
Expand Down Expand Up @@ -587,22 +586,14 @@ def test_empty_sql_parameter(self):
hook.run(sql=empty_statement)
assert err.value.args[0] == "List of SQL statements is empty"

@pytest.mark.parametrize(
"returned_schema,expected_schema",
[([None], ""), (["DATABASE.SCHEMA"], "SCHEMA")],
)
@mock.patch("airflow.providers.common.sql.hooks.sql.DbApiHook.get_first")
def test_get_openlineage_default_schema_with_no_schema_set(
self, mock_get_first, returned_schema, expected_schema
):
def test_get_openlineage_default_schema_with_no_schema_set(self):
connection_kwargs = {
**BASE_CONNECTION_KWARGS,
"schema": None,
"schema": "PUBLIC",
}
with mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()):
hook = SnowflakeHook(snowflake_conn_id="test_conn")
mock_get_first.return_value = returned_schema
assert hook.get_openlineage_default_schema() == expected_schema
assert hook.get_openlineage_default_schema() == "PUBLIC"

@mock.patch("airflow.providers.common.sql.hooks.sql.DbApiHook.get_first")
def test_get_openlineage_default_schema_with_schema_set(self, mock_get_first):
Expand Down

0 comments on commit 3b26f58

Please sign in to comment.