Skip to content

Commit

Permalink
openlineage, snowflake: add OpenLineage support for Snowflake (#31696)
Browse files Browse the repository at this point in the history
* Add OpenLineage support for SnowflakeOperator.

Signed-off-by: Jakub Dardzinski <kuba0221@gmail.com>

* Change how default schema is retrieved.

Signed-off-by: Jakub Dardzinski <kuba0221@gmail.com>

---------

Signed-off-by: Jakub Dardzinski <kuba0221@gmail.com>
  • Loading branch information
JDarDagran committed Jul 21, 2023
1 parent 98a9990 commit 5b082c3
Show file tree
Hide file tree
Showing 4 changed files with 169 additions and 1 deletion.
71 changes: 70 additions & 1 deletion airflow/providers/snowflake/hooks/snowflake.py
Expand Up @@ -22,7 +22,8 @@
from functools import wraps
from io import StringIO
from pathlib import Path
from typing import Any, Callable, Iterable, Mapping, TypeVar, overload
from typing import TYPE_CHECKING, Any, Callable, Iterable, Mapping, TypeVar, overload
from urllib.parse import urlparse

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
Expand All @@ -36,6 +37,9 @@
from airflow.utils.strings import to_boolean

T = TypeVar("T")
if TYPE_CHECKING:
from airflow.providers.openlineage.extractors import OperatorLineage
from airflow.providers.openlineage.sqlparser import DatabaseInfo


def _try_to_boolean(value: Any):
Expand Down Expand Up @@ -448,3 +452,68 @@ def _get_cursor(self, conn: Any, return_dictionaries: bool):
finally:
if cursor is not None:
cursor.close()

def get_openlineage_database_info(self, connection) -> DatabaseInfo:
from airflow.providers.openlineage.sqlparser import DatabaseInfo

database = self.database or self._get_field(connection.extra_dejson, "database")

return DatabaseInfo(
scheme=self.get_openlineage_database_dialect(connection),
authority=self._get_openlineage_authority(connection),
information_schema_columns=[
"table_schema",
"table_name",
"column_name",
"ordinal_position",
"data_type",
],
database=database,
is_information_schema_cross_db=True,
is_uppercase_names=True,
)

def get_openlineage_database_dialect(self, _) -> str:
return "snowflake"

def get_openlineage_default_schema(self) -> str | None:
"""
Attempts to get current schema.
Usually ``SELECT CURRENT_SCHEMA();`` should work.
However, apparently you may set ``database`` without ``schema``
and get results from ``SELECT CURRENT_SCHEMAS();`` but not
from ``SELECT CURRENT_SCHEMA();``.
It still may return nothing if no database is set in connection.
"""
schema = self._get_conn_params()["schema"]
if not schema:
current_schemas = self.get_first("SELECT PARSE_JSON(CURRENT_SCHEMAS())[0]::string;")[0]
if current_schemas:
_, schema = current_schemas.split(".")
return schema

def _get_openlineage_authority(self, _) -> str:
from openlineage.common.provider.snowflake import fix_snowflake_sqlalchemy_uri

uri = fix_snowflake_sqlalchemy_uri(self.get_uri())
return urlparse(uri).hostname

def get_openlineage_database_specific_lineage(self, _) -> OperatorLineage | None:
from openlineage.client.facet import ExternalQueryRunFacet

from airflow.providers.openlineage.extractors import OperatorLineage
from airflow.providers.openlineage.sqlparser import SQLParser

connection = self.get_connection(getattr(self, self.conn_name_attr))
namespace = SQLParser.create_namespace(self.get_database_info(connection))

if self.query_ids:
return OperatorLineage(
run_facets={
"externalQuery": ExternalQueryRunFacet(
externalQueryId=self.query_ids[0], source=namespace
)
}
)
return None
1 change: 1 addition & 0 deletions generated/provider_dependencies.json
Expand Up @@ -814,6 +814,7 @@
],
"cross-providers-deps": [
"common.sql",
"openlineage",
"slack"
],
"excluded-python-versions": []
Expand Down
30 changes: 30 additions & 0 deletions tests/providers/snowflake/hooks/test_snowflake.py
Expand Up @@ -621,3 +621,33 @@ def test___ensure_prefixes(self):
"extra__snowflake__private_key_content",
"extra__snowflake__insecure_mode",
]

@pytest.mark.parametrize(
"returned_schema,expected_schema",
[([None], ""), (["DATABASE.SCHEMA"], "SCHEMA")],
)
@mock.patch("airflow.providers.common.sql.hooks.sql.DbApiHook.get_first")
def test_get_openlineage_default_schema_with_no_schema_set(
self, mock_get_first, returned_schema, expected_schema
):
connection_kwargs = {
**BASE_CONNECTION_KWARGS,
"schema": None,
}
with mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()):
hook = SnowflakeHook(snowflake_conn_id="test_conn")
mock_get_first.return_value = returned_schema
assert hook.get_openlineage_default_schema() == expected_schema

@mock.patch("airflow.providers.common.sql.hooks.sql.DbApiHook.get_first")
def test_get_openlineage_default_schema_with_schema_set(self, mock_get_first):
with mock.patch.dict(
"os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**BASE_CONNECTION_KWARGS).get_uri()
):
hook = SnowflakeHook(snowflake_conn_id="test_conn")
assert hook.get_openlineage_default_schema() == BASE_CONNECTION_KWARGS["schema"]
mock_get_first.assert_not_called()

hook_with_schema_param = SnowflakeHook(snowflake_conn_id="test_conn", schema="my_schema")
assert hook_with_schema_param.get_openlineage_default_schema() == "my_schema"
mock_get_first.assert_not_called()
68 changes: 68 additions & 0 deletions tests/providers/snowflake/operators/test_snowflake_sql.py
Expand Up @@ -21,8 +21,12 @@

import pytest
from databricks.sql.types import Row
from openlineage.client.facet import SchemaDatasetFacet, SchemaField, SqlJobFacet
from openlineage.client.run import Dataset

from airflow.models.connection import Connection
from airflow.providers.common.sql.hooks.sql import fetch_all_handler
from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook
from airflow.providers.snowflake.operators.snowflake import SnowflakeOperator

DATE = "2017-04-20"
Expand Down Expand Up @@ -138,3 +142,67 @@ def test_exec_success(sql, return_last, split_statement, hook_results, hook_desc
return_last=return_last,
split_statements=split_statement,
)


def test_execute_openlineage_events():
DB_NAME = "DATABASE"
DB_SCHEMA_NAME = "PUBLIC"

class SnowflakeHookForTests(SnowflakeHook):
get_conn = MagicMock(name="conn")
get_connection = MagicMock()

def get_first(self, *_):
return [f"{DB_NAME}.{DB_SCHEMA_NAME}"]

dbapi_hook = SnowflakeHookForTests()

class SnowflakeOperatorForTest(SnowflakeOperator):
def get_db_hook(self):
return dbapi_hook

sql = """CREATE TABLE IF NOT EXISTS popular_orders_day_of_week (
order_day_of_week VARCHAR(64) NOT NULL,
order_placed_on TIMESTAMP NOT NULL,
orders_placed INTEGER NOT NULL
);
FORGOT TO COMMENT"""
op = SnowflakeOperatorForTest(task_id="snowflake-operator", sql=sql)
rows = [
(DB_SCHEMA_NAME, "POPULAR_ORDERS_DAY_OF_WEEK", "ORDER_DAY_OF_WEEK", 1, "TEXT"),
(DB_SCHEMA_NAME, "POPULAR_ORDERS_DAY_OF_WEEK", "ORDER_PLACED_ON", 2, "TIMESTAMP_NTZ"),
(DB_SCHEMA_NAME, "POPULAR_ORDERS_DAY_OF_WEEK", "ORDERS_PLACED", 3, "NUMBER"),
]
dbapi_hook.get_connection.return_value = Connection(
conn_id="snowflake_default",
conn_type="snowflake",
extra={
"account": "test_account",
"region": "us-east",
"warehouse": "snow-warehouse",
"database": DB_NAME,
},
)
dbapi_hook.get_conn.return_value.cursor.return_value.fetchall.side_effect = [rows, []]

lineage = op.get_openlineage_facets_on_start()
assert len(lineage.inputs) == 0
assert lineage.outputs == [
Dataset(
namespace="snowflake://test_account.us-east.aws",
name=f"{DB_NAME}.{DB_SCHEMA_NAME}.POPULAR_ORDERS_DAY_OF_WEEK",
facets={
"schema": SchemaDatasetFacet(
fields=[
SchemaField(name="ORDER_DAY_OF_WEEK", type="TEXT"),
SchemaField(name="ORDER_PLACED_ON", type="TIMESTAMP_NTZ"),
SchemaField(name="ORDERS_PLACED", type="NUMBER"),
]
)
},
)
]

assert lineage.job_facets == {"sql": SqlJobFacet(query=sql)}

assert lineage.run_facets["extractionError"].failedTasks == 1

0 comments on commit 5b082c3

Please sign in to comment.