diff --git a/airflow/providers/openlineage/extractors/manager.py b/airflow/providers/openlineage/extractors/manager.py index a5654d8bbfb65..0f80d52cf307e 100644 --- a/airflow/providers/openlineage/extractors/manager.py +++ b/airflow/providers/openlineage/extractors/manager.py @@ -177,11 +177,30 @@ def extract_inlets_and_outlets( if d: task_metadata.outputs.append(d) + @staticmethod + def convert_to_ol_dataset_from_object_storage_uri(uri: str): + from urllib.parse import urlparse + + from openlineage.client.run import Dataset + + try: + scheme, netloc, path, params, _, _ = urlparse(uri) + except Exception: + return None + if scheme.startswith("s3"): + return Dataset(namespace=f"s3://{netloc}", name=path.lstrip("/")) + elif scheme.startswith(("gcs", "gs")): + return Dataset(namespace=f"gs://{netloc}", name=path.lstrip("/")) + elif "/" not in uri: + return None + return Dataset(namespace=scheme, name=f"{netloc}{path}") + @staticmethod def convert_to_ol_dataset(obj): + from openlineage.client.facet import SchemaDatasetFacet, SchemaField from openlineage.client.run import Dataset - from airflow.lineage.entities import Table + from airflow.lineage.entities import File, Table if isinstance(obj, Dataset): return obj @@ -189,8 +208,24 @@ def convert_to_ol_dataset(obj): return Dataset( namespace=f"{obj.cluster}", name=f"{obj.database}.{obj.name}", - facets={}, + facets={ + "schema": SchemaDatasetFacet( + fields=[ + SchemaField( + name=column.name, + type=column.data_type, + description=column.description, + ) + for column in obj.columns + ] + ) + } + if obj.columns + else {}, ) + + elif isinstance(obj, File): + return ExtractorManager.convert_to_ol_dataset_from_object_storage_uri(obj.url) else: return None diff --git a/tests/always/test_project_structure.py b/tests/always/test_project_structure.py index bab56abead2a1..85c2c306555b7 100644 --- a/tests/always/test_project_structure.py +++ b/tests/always/test_project_structure.py @@ -159,7 +159,6 @@ def 
test_providers_modules_should_have_tests(self): "tests/providers/microsoft/azure/operators/test_adls.py", "tests/providers/microsoft/azure/transfers/test_azure_blob_to_gcs.py", "tests/providers/mongo/sensors/test_mongo.py", - "tests/providers/openlineage/extractors/test_manager.py", "tests/providers/openlineage/plugins/test_adapter.py", "tests/providers/openlineage/plugins/test_facets.py", "tests/providers/openlineage/test_sqlparser.py", diff --git a/tests/providers/openlineage/extractors/test_manager.py b/tests/providers/openlineage/extractors/test_manager.py new file mode 100644 index 0000000000000..23760e6e14f2f --- /dev/null +++ b/tests/providers/openlineage/extractors/test_manager.py @@ -0,0 +1,93 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import pytest +from openlineage.client.facet import SchemaDatasetFacet, SchemaField +from openlineage.client.run import Dataset + +from airflow.lineage.entities import Column, File, Table +from airflow.providers.openlineage.extractors.manager import ExtractorManager + + +@pytest.mark.parametrize( + ("uri", "dataset"), + ( + ("s3://bucket1/dir1/file1", Dataset(namespace="s3://bucket1", name="dir1/file1")), + ("gs://bucket2/dir2/file2", Dataset(namespace="gs://bucket2", name="dir2/file2")), + ("gcs://bucket3/dir3/file3", Dataset(namespace="gs://bucket3", name="dir3/file3")), + ("https://test.com", Dataset(namespace="https", name="test.com")), + ("https://test.com?param1=test1&param2=test2", Dataset(namespace="https", name="test.com")), + ("not_an_url", None), + ), +) +def test_convert_to_ol_dataset_from_object_storage_uri(uri, dataset): + result = ExtractorManager.convert_to_ol_dataset_from_object_storage_uri(uri) + assert result == dataset + + +@pytest.mark.parametrize( + ("obj", "dataset"), + ( + ( + Dataset(namespace="n1", name="f1"), + Dataset(namespace="n1", name="f1"), + ), + (File(url="s3://bucket1/dir1/file1"), Dataset(namespace="s3://bucket1", name="dir1/file1")), + (File(url="gs://bucket2/dir2/file2"), Dataset(namespace="gs://bucket2", name="dir2/file2")), + (File(url="https://test.com"), Dataset(namespace="https", name="test.com")), + (Table(cluster="c1", database="d1", name="t1"), Dataset(namespace="c1", name="d1.t1")), + ("gs://bucket2/dir2/file2", None), + ("not_an_url", None), + ), +) +def test_convert_to_ol_dataset(obj, dataset): + result = ExtractorManager.convert_to_ol_dataset(obj) + assert result == dataset + + +def test_convert_to_ol_dataset_from_table_with_columns(): + table = Table( + cluster="c1", + database="d1", + name="t1", + columns=[ + Column(name="col1", description="desc1", data_type="type1"), + Column(name="col2", description="desc2", data_type="type2"), + ], + ) + result = 
ExtractorManager.convert_to_ol_dataset(table) + expected_facets = { + "schema": SchemaDatasetFacet( + fields=[ + SchemaField( + name="col1", + type="type1", + description="desc1", + ), + SchemaField( + name="col2", + type="type2", + description="desc2", + ), + ] + ) + } + assert result.namespace == "c1" + assert result.name == "d1.t1" + assert result.facets == expected_facets