From 47858270b0d4ec608967ed200806d3c0dafcfaf3 Mon Sep 17 00:00:00 2001
From: Phani Kumar <94376113+phanikumv@users.noreply.github.com>
Date: Fri, 4 Nov 2022 22:01:46 +0530
Subject: [PATCH] Openlineage support - Add Extractor for `DataframeOperator` (#1183)

# Description

## What is the current behavior?

OpenLineage metadata cannot currently be extracted from the `DataframeOperator`.

closes: #904

## What is the new behavior?

- Added a `get_openlineage_facets` method to `DataframeOperator` (based on [this doc](https://docs.google.com/document/d/1vPsvHejQ24xTbzpz_LYSf0_ixk9oUuBiEUHVEaF9J2U/edit?usp=sharing)).
- The `PythonSDKExtractor` built in https://github.com/astronomer/astro-sdk/issues/898 works with `DataframeOperator.get_openlineage_facets`; this is verified by the new test in `python-sdk/tests/extractors/test_extractor.py`. A sketch of the kind of DAG this covers is shown below.
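A minimal sketch of such a DAG (illustrative only: the DAG id, sample data, and the `sqlite_default` connection are assumptions mirroring the new test, and import paths follow this repo's python-sdk layout). When the task writes to an `output_table` and a lineage backend is configured, the extractor can now report the table's `schema`, `dataSource`, and `outputStatistics` facets:

```python
import pandas as pd
import pendulum
from airflow.decorators import dag

from astro import sql as aql
from astro.sql.table import Metadata, Table


@aql.dataframe(columns_names_capitalization="original")
def aggregate_data(df: pd.DataFrame):
    # Any pandas transformation; the returned dataframe is materialized
    # into the output_table passed at call time.
    df.columns = df.columns.str.lower()
    return df


@dag(schedule_interval=None, start_date=pendulum.datetime(2022, 1, 1), catchup=False)
def example_dataframe_lineage():
    aggregate_data(
        pd.DataFrame([["a", "b", "c"]], columns=["COL_A", "COL_B", "COL_C"]),
        output_table=Table(
            name="test_tbl",
            metadata=Metadata(schema="test_schema", database="test_db"),
            conn_id="sqlite_default",
        ),
    )


example_dataframe_lineage()
```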
## Does this introduce a breaking change?

No

### Checklist

- [x] Post the screenshot of how it looks in the Openlineage/Marquez UI
- [x] All checks and tests in the CI should pass
- [x] Unit tests (90% code coverage or more, https://github.com/astronomer/astro-sdk/issues/191)
- [ ] Integration tests (if the feature relates to a new database or external service)
- [ ] Docstrings in [reStructuredText](https://peps.python.org/pep-0287/) for each of methods, classes, functions and module-level attributes (including Example DAG on how it should be used)
- [ ] Exception handling in case of errors
- [ ] Logging (are we exposing useful information to the user? e.g. source and destination)
- [ ] Improve the documentation (README, Sphinx, and any other relevant)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Pankaj Singh <98807258+pankajastro@users.noreply.github.com>
Co-authored-by: Pankaj
---
 python-sdk/src/astro/lineage/extractor.py     |  1 +
 .../src/astro/sql/operators/dataframe.py      | 53 ++++++++++++++++++-
 python-sdk/tests/extractors/test_extractor.py | 52 ++++++++++++++++++
 3 files changed, 105 insertions(+), 1 deletion(-)

diff --git a/python-sdk/src/astro/lineage/extractor.py b/python-sdk/src/astro/lineage/extractor.py
index 7061728103..6a881b1379 100644
--- a/python-sdk/src/astro/lineage/extractor.py
+++ b/python-sdk/src/astro/lineage/extractor.py
@@ -34,6 +34,7 @@ def get_operator_classnames(cls) -> list[str]:
         return [
             "AppendOperator",
             "BaseSQLDecoratedOperator",
+            "DataframeOperator",
             "ExportFileOperator",
             "LoadFileOperator",
             "MergeOperator",
diff --git a/python-sdk/src/astro/sql/operators/dataframe.py b/python-sdk/src/astro/sql/operators/dataframe.py
index 90196632d3..4accac2bef 100644
--- a/python-sdk/src/astro/sql/operators/dataframe.py
+++ b/python-sdk/src/astro/sql/operators/dataframe.py
@@ -11,13 +11,23 @@
 
 try:
     from airflow.decorators.base import TaskDecorator, task_decorator_factory
-except ImportError:
+except ImportError:  # pragma: no cover
    from airflow.decorators.base import task_decorator_factory
    from airflow.decorators import _TaskDecorator as TaskDecorator
 
+from openlineage.client.facet import (
+    BaseFacet,
+    DataSourceDatasetFacet,
+    OutputStatisticsOutputDatasetFacet,
+    SchemaDatasetFacet,
+    SchemaField,
+)
+from openlineage.client.run import Dataset as OpenlineageDataset
+
 from astro.constants import ColumnCapitalization
 from astro.databases import create_database
 from astro.files import File
+from astro.lineage.extractor import OpenLineageFacets
 from astro.sql.operators.base_operator import AstroSQLBaseOperator
 from astro.sql.table import BaseTable, Table
 from astro.utils.dataframe import convert_columns_names_capitalization
@@ -215,6 +225,47 @@ def _convert_column_capitalization_for_output(function_output, columns_names_cap
             )
         return function_output
 
+    def get_openlineage_facets(self, task_instance) -> OpenLineageFacets:  # skipcq: PYL-W0613
+        """
+        Collect the input, output, job and run facets for DataframeOperator
+        """
+        output_dataset: list[OpenlineageDataset] = []
+
+        if self.output_table and self.output_table.openlineage_emit_temp_table_event():  # pragma: no cover
+            output_uri = (
+                f"{self.output_table.openlineage_dataset_namespace()}"
+                f"/{self.output_table.openlineage_dataset_name()}"
+            )
+
+            output_dataset = [
+                OpenlineageDataset(
+                    namespace=self.output_table.openlineage_dataset_namespace(),
+                    name=self.output_table.openlineage_dataset_name(),
+                    facets={
+                        "schema": SchemaDatasetFacet(
+                            fields=[
+                                SchemaField(
+                                    name=self.schema if self.schema else self.output_table.metadata.schema,
+                                    type=self.database
+                                    if self.database
+                                    else self.output_table.metadata.database,
+                                )
+                            ]
+                        ),
+                        "dataSource": DataSourceDatasetFacet(name=self.output_table.name, uri=output_uri),
+                        "outputStatistics": OutputStatisticsOutputDatasetFacet(
+                            rowCount=self.output_table.row_count,
+                        ),
+                    },
+                ),
+            ]
+
+        run_facets: dict[str, BaseFacet] = {}
+        job_facets: dict[str, BaseFacet] = {}
+        return OpenLineageFacets(
+            inputs=[], outputs=output_dataset, run_facets=run_facets, job_facets=job_facets
+        )
+
 
 def dataframe(
     python_callable: Callable | None = None,
diff --git a/python-sdk/tests/extractors/test_extractor.py b/python-sdk/tests/extractors/test_extractor.py
index f12ed1f539..d85fa9d7a9 100644
--- a/python-sdk/tests/extractors/test_extractor.py
+++ b/python-sdk/tests/extractors/test_extractor.py
@@ -1,3 +1,4 @@
+import pandas as pd
 import pendulum
 import pytest
 from airflow.models.taskinstance import TaskInstance
@@ -286,3 +287,54 @@ def top_five_animations(input_table: Table) -> str:
     assert task_meta.outputs[0].facets["outputStatistics"].size is None
     assert len(task_meta.job_facets) > 0
     assert task_meta.run_facets == {}
+
+
+@pytest.mark.integration
+def test_python_sdk_dataframe_op_extract_on_complete():
+    """
+    Tests that the custom PythonSDKExtractor is able to process the
+    operator's metadata that needs to be extracted as per OpenLineage
+    for DataframeOperator.
+ """ + + @aql.dataframe(columns_names_capitalization="original") + def aggregate_data(df: pd.DataFrame): + new_df = df + new_df.columns = new_df.columns.str.lower() + return new_df + + test_list = [["a", "b", "c"], ["AA", "BB", "CC"]] + dfList = pd.DataFrame(test_list, columns=["COL_A", "COL_B", "COL_C"]) + test_tbl_name = "test_tbl" + test_schema_name = "test_schema" + test_db_name = "test_db" + + task = ( + aggregate_data( + dfList, + output_table=Table( + name=test_tbl_name, + metadata=Metadata( + schema=test_schema_name, + database=test_db_name, + ), + conn_id="sqlite_default", + ), + ), + ) + + task[0].operator.execute(context=create_context(task[0].operator)) + + tzinfo = pendulum.timezone("UTC") + execution_date = timezone.datetime(2022, 1, 1, 1, 0, 0, tzinfo=tzinfo) + task_instance = TaskInstance(task=task[0].operator, run_id=execution_date) + python_sdk_extractor = PythonSDKExtractor(task[0].operator) + + assert type(python_sdk_extractor.get_operator_classnames()) is list + task_meta = python_sdk_extractor.extract_on_complete(task_instance) + assert task_meta.name == "adhoc_airflow.aggregate_data" + assert task_meta.outputs[0].facets["schema"].fields[0].name == test_schema_name + assert task_meta.outputs[0].facets["schema"].fields[0].type == test_db_name + assert task_meta.outputs[0].facets["dataSource"].name == test_tbl_name + assert task_meta.outputs[0].facets["outputStatistics"].rowCount == len(test_list) + assert task_meta.run_facets == {}