Skip to content

Commit

Permalink
Add OpenLineage support for Trino. (#32910)
Browse files Browse the repository at this point in the history
Signed-off-by: Jakub Dardzinski <kuba0221@gmail.com>
  • Loading branch information
JDarDagran committed Aug 24, 2023
1 parent 2d86252 commit 626d3da
Show file tree
Hide file tree
Showing 7 changed files with 134 additions and 18 deletions.
17 changes: 14 additions & 3 deletions airflow/providers/openlineage/utils/sql.py
Expand Up @@ -155,11 +155,22 @@ def create_information_schema_query(
metadata = MetaData(sqlalchemy_engine)
select_statements = []
for db, schema_mapping in tables_hierarchy.items():
schema, table_name = information_schema_table_name.split(".")
# The information schema table name is expected to be "<information_schema schema>.<view/table name>",
# usually "information_schema.columns". To build a table identifier that is correct in each case below,
# we need to pass the first part of the dot-separated identifier as the `schema` argument
# to `sqlalchemy.Table`.
if db:
schema = f"{db}.{schema}"
# Use database as first part of table identifier.
schema = db
table_name = information_schema_table_name
else:
# When no database passed, use schema as first part of table identifier.
schema, table_name = information_schema_table_name.split(".")
information_schema_table = Table(
table_name, metadata, *[Column(column) for column in columns], schema=schema
table_name,
metadata,
*[Column(column) for column in columns],
schema=schema,
quote=False,
)
filter_clauses = create_filter_clauses(schema_mapping, information_schema_table, uppercase_names)
select_statements.append(information_schema_table.select().filter(*filter_clauses))
Expand Down
29 changes: 29 additions & 0 deletions airflow/providers/trino/hooks/trino.py
Expand Up @@ -233,3 +233,32 @@ def _serialize_cell(cell: Any, conn: Connection | None = None) -> Any:
:return: The cell
"""
return cell

def get_openlineage_database_info(self, connection):
    """Build the Trino-specific ``DatabaseInfo`` used by OpenLineage.

    :param connection: Airflow connection. It is handed to
        ``DbApiHook.get_openlineage_authority_part`` (with Trino's default port) and its
        ``catalog`` extra, defaulting to ``"hive"``, becomes the database.
    :return: ``DatabaseInfo`` with scheme ``"trino"`` and the cross-database
        information schema flag enabled.
    """
    # Function-scope import; presumably keeps the openlineage provider optional -- TODO confirm.
    from airflow.providers.openlineage.sqlparser import DatabaseInfo

    authority_part = DbApiHook.get_openlineage_authority_part(
        connection, default_port=trino.constants.DEFAULT_PORT
    )
    # Columns selected from the information schema, in this exact order.
    schema_columns = [
        "table_schema",
        "table_name",
        "column_name",
        "ordinal_position",
        "data_type",
        "table_catalog",
    ]
    catalog = connection.extra_dejson.get("catalog", "hive")

    return DatabaseInfo(
        scheme="trino",
        authority=authority_part,
        information_schema_columns=schema_columns,
        database=catalog,
        is_information_schema_cross_db=True,
    )

def get_openlineage_database_dialect(self, _):
    """Return the SQL dialect name reported to OpenLineage, which is always ``"trino"``.

    The second positional argument is accepted but ignored.
    """
    dialect = "trino"
    return dialect

def get_openlineage_default_schema(self):
    """Return the default schema for OpenLineage: ``trino.constants.DEFAULT_SCHEMA``."""
    default_schema = trino.constants.DEFAULT_SCHEMA
    return default_schema
4 changes: 2 additions & 2 deletions dev/breeze/tests/test_provider_dependencies.py
Expand Up @@ -25,7 +25,7 @@ def test_get_downstream_only():
related_providers = get_related_providers(
"trino", upstream_dependencies=False, downstream_dependencies=True
)
assert {"google", "common.sql"} == related_providers
assert {"openlineage", "google", "common.sql"} == related_providers


def test_get_upstream_only():
Expand All @@ -39,7 +39,7 @@ def test_both():
related_providers = get_related_providers(
"trino", upstream_dependencies=True, downstream_dependencies=True
)
assert {"google", "mysql", "common.sql"} == related_providers
assert {"openlineage", "google", "mysql", "common.sql"} == related_providers


def test_none():
Expand Down
3 changes: 2 additions & 1 deletion generated/provider_dependencies.json
Expand Up @@ -892,7 +892,8 @@
],
"cross-providers-deps": [
"common.sql",
"google"
"google",
"openlineage"
],
"excluded-python-versions": []
},
Expand Down
11 changes: 11 additions & 0 deletions tests/integration/providers/trino/hooks/test_trino.py
Expand Up @@ -22,6 +22,7 @@
import pytest

from airflow.providers.trino.hooks.trino import TrinoHook
from airflow.providers.trino.operators.trino import TrinoOperator


@pytest.mark.integration("trino")
Expand All @@ -46,3 +47,13 @@ def test_should_record_records_with_kerberos_auth(self):
sql = "SELECT name FROM tpch.sf1.customer ORDER BY custkey ASC LIMIT 3"
records = hook.get_records(sql)
assert [["Customer#000000001"], ["Customer#000000002"], ["Customer#000000003"]] == records

@mock.patch.dict("os.environ", AIRFLOW_CONN_TRINO_DEFAULT="trino://airflow@trino:8080/")
def test_openlineage_methods(self):
    """Run a query through ``TrinoOperator`` and check the OpenLineage facets it reports.

    The default Trino connection is supplied through the patched
    ``AIRFLOW_CONN_TRINO_DEFAULT`` environment variable.
    """
    query = "SELECT name FROM tpch.sf1.customer LIMIT 3"
    operator = TrinoOperator(task_id="trino_test", sql=query)
    operator.execute({})

    facets = operator.get_openlineage_facets_on_start()
    first_input = facets.inputs[0]

    # The input dataset is identified by the connection authority and the fully qualified table.
    assert first_input.namespace == "trino://trino:8080"
    assert first_input.name == "tpch.sf1.customer"
    assert "schema" in first_input.facets
    # The job facet carries the SQL text exactly as passed to the operator.
    assert facets.job_facets["sql"].query == query
24 changes: 12 additions & 12 deletions tests/providers/openlineage/utils/test_sql.py
Expand Up @@ -327,17 +327,17 @@ def test_create_create_information_schema_query_cross_db():
information_schema_table_name="information_schema.columns",
tables_hierarchy={"db": {"schema1": ["table1"]}, "db2": {"schema1": ["table2"]}},
)
== 'SELECT "db.information_schema".columns.table_schema, "db.information_schema".columns.table_name, '
'"db.information_schema".columns.column_name, "db.information_schema".columns.ordinal_position, '
'"db.information_schema".columns.data_type \n'
'FROM "db.information_schema".columns \n'
"WHERE \"db.information_schema\".columns.table_schema = 'schema1' "
"AND \"db.information_schema\".columns.table_name IN ('table1') "
== "SELECT db.information_schema.columns.table_schema, db.information_schema.columns.table_name, "
"db.information_schema.columns.column_name, db.information_schema.columns.ordinal_position, "
"db.information_schema.columns.data_type \n"
"FROM db.information_schema.columns \n"
"WHERE db.information_schema.columns.table_schema = 'schema1' "
"AND db.information_schema.columns.table_name IN ('table1') "
"UNION ALL "
'SELECT "db2.information_schema".columns.table_schema, "db2.information_schema".columns.table_name, '
'"db2.information_schema".columns.column_name, "db2.information_schema".columns.ordinal_position, '
'"db2.information_schema".columns.data_type \n'
'FROM "db2.information_schema".columns \n'
"WHERE \"db2.information_schema\".columns.table_schema = 'schema1' "
"AND \"db2.information_schema\".columns.table_name IN ('table2')"
"SELECT db2.information_schema.columns.table_schema, db2.information_schema.columns.table_name, "
"db2.information_schema.columns.column_name, db2.information_schema.columns.ordinal_position, "
"db2.information_schema.columns.data_type \n"
"FROM db2.information_schema.columns \n"
"WHERE db2.information_schema.columns.table_schema = 'schema1' "
"AND db2.information_schema.columns.table_name IN ('table2')"
)
64 changes: 64 additions & 0 deletions tests/providers/trino/operators/test_trino.py
Expand Up @@ -20,8 +20,12 @@
from unittest import mock

import pytest
from openlineage.client.facet import SchemaDatasetFacet, SchemaField, SqlJobFacet
from openlineage.client.run import Dataset

from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.models.connection import Connection
from airflow.providers.trino.hooks.trino import TrinoHook
from airflow.providers.trino.operators.trino import TrinoOperator

TRINO_CONN_ID = "test_trino"
Expand Down Expand Up @@ -49,3 +53,63 @@ def test_execute(self, mock_get_db_hook):
parameters=None,
return_last=True,
)


def test_execute_openlineage_events():
    """Check the OpenLineage facets ``TrinoOperator`` reports against a mocked Trino hook.

    The hook's connection and cursor are mocks. The first ``fetchall`` call returns the
    information-schema rows for ``tpch.sf1.customer`` and the second returns nothing.
    """
    catalog, schema = "tpch", "sf1"
    # (column name, ordinal position, data type) of the mocked ``customer`` table.
    customer_columns = [
        ("custkey", 1, "bigint"),
        ("name", 2, "varchar(25)"),
        ("address", 3, "varchar(40)"),
        ("nationkey", 4, "bigint"),
        ("phone", 5, "varchar(15)"),
        ("acctbal", 6, "double"),
    ]

    class TrinoHookForTests(TrinoHook):
        # Class-level mocks stand in for the real connection handling.
        get_conn = mock.MagicMock(name="conn")
        get_connection = mock.MagicMock()

        def get_first(self, *_):
            return [f"{catalog}.{schema}"]

    hook = TrinoHookForTests()

    class TrinoOperatorForTest(TrinoOperator):
        def get_db_hook(self):
            return hook

    sql = "SELECT name FROM tpch.sf1.customer LIMIT 3"
    operator = TrinoOperatorForTest(task_id="trino-operator", sql=sql)

    # Rows follow the column order: schema, table, column, ordinal position, data type, catalog.
    information_schema_rows = [
        (schema, "customer", column, position, data_type, catalog)
        for column, position, data_type in customer_columns
    ]
    hook.get_connection.return_value = Connection(
        conn_id="trino_default",
        conn_type="trino",
        host="trino",
        port=8080,
    )
    fetchall = hook.get_conn.return_value.cursor.return_value.fetchall
    fetchall.side_effect = [information_schema_rows, []]

    lineage = operator.get_openlineage_facets_on_start()

    expected_fields = [
        SchemaField(name=column, type=data_type) for column, _, data_type in customer_columns
    ]
    expected_input = Dataset(
        namespace="trino://trino:8080",
        name=f"{catalog}.{schema}.customer",
        facets={"schema": SchemaDatasetFacet(fields=expected_fields)},
    )
    assert lineage.inputs == [expected_input]

    assert len(lineage.outputs) == 0

    assert lineage.job_facets == {"sql": SqlJobFacet(query=sql)}

0 comments on commit 626d3da

Please sign in to comment.