Skip to content

Commit

Permalink
Send column lineage from SQL operators. (#34843)
Browse files Browse the repository at this point in the history
Signed-off-by: Jakub Dardzinski <kuba0221@gmail.com>
  • Loading branch information
JDarDagran committed Oct 25, 2023
1 parent 0bb5631 commit 0940d09
Show file tree
Hide file tree
Showing 4 changed files with 215 additions and 51 deletions.
53 changes: 52 additions & 1 deletion airflow/providers/openlineage/sqlparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,15 @@

import sqlparse
from attrs import define
from openlineage.client.facet import BaseFacet, ExtractionError, ExtractionErrorRunFacet, SqlJobFacet
from openlineage.client.facet import (
BaseFacet,
ColumnLineageDatasetFacet,
ColumnLineageDatasetFacetFieldsAdditional,
ColumnLineageDatasetFacetFieldsAdditionalInputFields,
ExtractionError,
ExtractionErrorRunFacet,
SqlJobFacet,
)
from openlineage.common.sql import DbTableMeta, SqlMeta, parse

from airflow.providers.openlineage.extractors.base import OperatorLineage
Expand Down Expand Up @@ -143,6 +151,47 @@ def parse_table_schemas(
else None,
)

def attach_column_lineage(
self, datasets: list[Dataset], database: str | None, parse_result: SqlMeta
) -> None:
"""
Attaches column lineage facet to the list of datasets.
Note that currently each dataset has the same column lineage information set.
This would be a matter of change after OpenLineage SQL Parser improvements.
"""
if not len(parse_result.column_lineage):
return
for dataset in datasets:
dataset.facets["columnLineage"] = ColumnLineageDatasetFacet(
fields={
column_lineage.descendant.name: ColumnLineageDatasetFacetFieldsAdditional(
inputFields=[
ColumnLineageDatasetFacetFieldsAdditionalInputFields(
namespace=dataset.namespace,
name=".".join(
filter(
None,
(
column_meta.origin.database or database,
column_meta.origin.schema or self.default_schema,
column_meta.origin.name,
),
)
)
if column_meta.origin
else "",
field=column_meta.name,
)
for column_meta in column_lineage.lineage
],
transformationType="",
transformationDescription="",
)
for column_lineage in parse_result.column_lineage
}
)

def generate_openlineage_metadata_from_sql(
self,
sql: list[str] | str,
Expand Down Expand Up @@ -198,6 +247,8 @@ def generate_openlineage_metadata_from_sql(
sqlalchemy_engine=sqlalchemy_engine,
)

self.attach_column_lineage(outputs, database or database_info.database, parse_result)

return OperatorLineage(
inputs=inputs,
outputs=outputs,
Expand Down
16 changes: 16 additions & 0 deletions tests/integration/providers/openlineage/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
16 changes: 16 additions & 0 deletions tests/integration/providers/openlineage/operators/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
181 changes: 131 additions & 50 deletions tests/providers/openlineage/utils/test_sqlparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,14 @@
from unittest.mock import MagicMock

import pytest
from openlineage.client.facet import SchemaDatasetFacet, SchemaField, SqlJobFacet
from openlineage.client.facet import (
ColumnLineageDatasetFacet,
ColumnLineageDatasetFacetFieldsAdditional,
ColumnLineageDatasetFacetFieldsAdditionalInputFields,
SchemaDatasetFacet,
SchemaField,
SqlJobFacet,
)
from openlineage.client.run import Dataset
from openlineage.common.sql import DbTableMeta

Expand All @@ -33,16 +40,6 @@

NAMESPACE = "test_namespace"

SCHEMA_FACET = SchemaDatasetFacet(
fields=[
SchemaField(name="ID", type="int4"),
SchemaField(name="AMOUNT_OFF", type="int4"),
SchemaField(name="CUSTOMER_EMAIL", type="varchar"),
SchemaField(name="STARTS_ON", type="timestamp"),
SchemaField(name="ENDS_ON", type="timestamp"),
]
)


def normalize_name_lower(name: str) -> str:
return name.lower()
Expand All @@ -56,7 +53,8 @@ def test_get_tables_hierarchy(self):

# base check with db, no cross db
assert SQLParser._get_tables_hierarchy(
[DbTableMeta("Db.Schema1.Table1"), DbTableMeta("Db.Schema2.Table2")], normalize_name_lower
[DbTableMeta("Db.Schema1.Table1"), DbTableMeta("Db.Schema2.Table2")],
normalize_name_lower,
) == {None: {"schema1": ["Table1"], "schema2": ["Table2"]}}

# same, with cross db
Expand Down Expand Up @@ -148,20 +146,42 @@ def rows(name):
]

hook.get_conn.return_value.cursor.return_value.fetchall.side_effect = [
rows("TABLE_IN"),
rows("TABLE_OUT"),
rows("top_delivery_times"),
rows("popular_orders_day_of_week"),
]

expected_schema_facet = SchemaDatasetFacet(
fields=[
SchemaField(name="ID", type="int4"),
SchemaField(name="AMOUNT_OFF", type="int4"),
SchemaField(name="CUSTOMER_EMAIL", type="varchar"),
SchemaField(name="STARTS_ON", type="timestamp"),
SchemaField(name="ENDS_ON", type="timestamp"),
]
)

expected = (
[Dataset(namespace=NAMESPACE, name="PUBLIC.TABLE_IN", facets={"schema": SCHEMA_FACET})],
[Dataset(namespace=NAMESPACE, name="PUBLIC.TABLE_OUT", facets={"schema": SCHEMA_FACET})],
[
Dataset(
namespace=NAMESPACE,
name="PUBLIC.top_delivery_times",
facets={"schema": expected_schema_facet},
)
],
[
Dataset(
namespace=NAMESPACE,
name="PUBLIC.popular_orders_day_of_week",
facets={"schema": expected_schema_facet},
)
],
)

assert expected == parser.parse_table_schemas(
hook=hook,
namespace=NAMESPACE,
inputs=[DbTableMeta("TABLE_IN")],
outputs=[DbTableMeta("TABLE_OUT")],
inputs=[DbTableMeta("top_delivery_times")],
outputs=[DbTableMeta("popular_orders_day_of_week")],
database_info=db_info,
)

Expand All @@ -173,64 +193,125 @@ def test_generate_openlineage_metadata_from_sql(self, mock_parse, parser_returns

hook = MagicMock()

def rows(schema, table):
return [
(schema, table, "ID", 1, "int4"),
(schema, table, "AMOUNT_OFF", 2, "int4"),
(schema, table, "CUSTOMER_EMAIL", 3, "varchar"),
(schema, table, "STARTS_ON", 4, "timestamp"),
(schema, table, "ENDS_ON", 5, "timestamp"),
]
returned_schema = DB_SCHEMA_NAME if parser_returns_schema else None
returned_rows = [
[
(returned_schema, "top_delivery_times", "order_id", 1, "int4"),
(
returned_schema,
"top_delivery_times",
"order_placed_on",
2,
"timestamp",
),
(returned_schema, "top_delivery_times", "customer_email", 3, "varchar"),
],
[
(
returned_schema,
"popular_orders_day_of_week",
"order_day_of_week",
1,
"varchar",
),
(
returned_schema,
"popular_orders_day_of_week",
"order_placed_on",
2,
"timestamp",
),
(
returned_schema,
"popular_orders_day_of_week",
"orders_placed",
3,
"int4",
),
],
]

sql = """CREATE TABLE table_out (
ID int,
AMOUNT_OFF int,
CUSTOMER_EMAIL varchar,
STARTS_ON timestamp,
ENDS_ON timestamp
sql = """INSERT INTO popular_orders_day_of_week (order_day_of_week)
SELECT EXTRACT(ISODOW FROM order_placed_on) AS order_day_of_week
FROM top_delivery_times
--irrelevant comment
)
;
"""

hook.get_conn.return_value.cursor.return_value.fetchall.side_effect = [
rows(DB_SCHEMA_NAME if parser_returns_schema else None, "TABLE_IN"),
rows(DB_SCHEMA_NAME if parser_returns_schema else None, "TABLE_OUT"),
]
hook.get_conn.return_value.cursor.return_value.fetchall.side_effect = returned_rows

mock_sql_meta = MagicMock()
if parser_returns_schema:
mock_sql_meta.in_tables = [DbTableMeta("PUBLIC.TABLE_IN")]
mock_sql_meta.out_tables = [DbTableMeta("PUBLIC.TABLE_OUT")]
mock_sql_meta.in_tables = [DbTableMeta("PUBLIC.top_delivery_times")]
mock_sql_meta.out_tables = [DbTableMeta("PUBLIC.popular_orders_day_of_week")]
else:
mock_sql_meta.in_tables = [DbTableMeta("TABLE_IN")]
mock_sql_meta.out_tables = [DbTableMeta("TABLE_OUT")]
mock_sql_meta.in_tables = [DbTableMeta("top_delivery_times")]
mock_sql_meta.out_tables = [DbTableMeta("popular_orders_day_of_week")]
mock_column_lineage = MagicMock()
mock_column_lineage.descendant.name = "order_day_of_week"
mock_lineage = MagicMock()
mock_lineage.name = "order_placed_on"
mock_lineage.origin.name = "top_delivery_times"
mock_lineage.origin.database = None
mock_lineage.origin.schema = "PUBLIC" if parser_returns_schema else None
mock_column_lineage.lineage = [mock_lineage]

mock_sql_meta.column_lineage = [mock_column_lineage]
mock_sql_meta.errors = []

mock_parse.return_value = mock_sql_meta

formatted_sql = """CREATE TABLE table_out (
ID int,
AMOUNT_OFF int,
CUSTOMER_EMAIL varchar,
STARTS_ON timestamp,
ENDS_ON timestamp
formatted_sql = """INSERT INTO popular_orders_day_of_week (order_day_of_week)
SELECT EXTRACT(ISODOW FROM order_placed_on) AS order_day_of_week
FROM top_delivery_times
)"""
expected_schema = "PUBLIC" if parser_returns_schema else "ANOTHER_SCHEMA"
expected = OperatorLineage(
inputs=[
Dataset(
namespace="myscheme://host:port",
name=f"{expected_schema}.TABLE_IN",
facets={"schema": SCHEMA_FACET},
name=f"{expected_schema}.top_delivery_times",
facets={
"schema": SchemaDatasetFacet(
fields=[
SchemaField(name="order_id", type="int4"),
SchemaField(name="order_placed_on", type="timestamp"),
SchemaField(name="customer_email", type="varchar"),
]
)
},
)
],
outputs=[
Dataset(
namespace="myscheme://host:port",
name=f"{expected_schema}.TABLE_OUT",
facets={"schema": SCHEMA_FACET},
name=f"{expected_schema}.popular_orders_day_of_week",
facets={
"schema": SchemaDatasetFacet(
fields=[
SchemaField(name="order_day_of_week", type="varchar"),
SchemaField(name="order_placed_on", type="timestamp"),
SchemaField(name="orders_placed", type="int4"),
]
),
"columnLineage": ColumnLineageDatasetFacet(
fields={
"order_day_of_week": ColumnLineageDatasetFacetFieldsAdditional(
inputFields=[
ColumnLineageDatasetFacetFieldsAdditionalInputFields(
namespace="myscheme://host:port",
name=f"{expected_schema}.top_delivery_times",
field="order_placed_on",
)
],
transformationDescription="",
transformationType="",
)
}
),
},
)
],
job_facets={"sql": SqlJobFacet(query=formatted_sql)},
Expand Down

0 comments on commit 0940d09

Please sign in to comment.