From 9a311e7cff7adef37210d1721c3c4e3f115e054b Mon Sep 17 00:00:00 2001 From: Ankit Chaurasia <8670962+sunank200@users.noreply.github.com> Date: Thu, 2 Mar 2023 21:24:47 +0545 Subject: [PATCH] Add the DatabaseDataProvider and SqliteDataProvider (#1777) # Description ## What is the current behavior? - Add `Dataprovider` for Databases - Add `DataProvider` for Sqlite - Add non-native transfer implementation for s3 to Sqlite - Add non-native transfer implementation for GCS to Sqlite - Add example DAG closes: #1731 ## What is the new behavior? - Add `Dataprovider` for Databases - Add `DataProvider` for Sqlite - Add non-native transfer implementation for s3 to Sqlite - Add non-native transfer implementation for GCS to Sqlite - Add example DAG Following is the recording of example DAG running: https://astronomer.slack.com/archives/C03868KGF2Q/p1676662728169199 ## Does this introduce a breaking change? No ### Checklist - [x] Created tests which fail without the change (if possible) - [x] Extended the README / documentation, if necessary --------- Co-authored-by: utkarsh sharma Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../example_universal_transfer_operator.py | 31 +- pyproject.toml | 10 +- src/universal_transfer_operator/constants.py | 1 + .../data_providers/__init__.py | 16 +- .../data_providers/base.py | 8 + .../data_providers/database/__init__.py | 0 .../data_providers/database/base.py | 615 ++++++++++++++++++ .../data_providers/database/sqlite.py | 139 ++++ .../data_providers/filesystem/__init__.py | 44 ++ .../data_providers/filesystem/aws/s3.py | 8 + .../data_providers/filesystem/base.py | 13 +- .../filesystem/google/cloud/gcs.py | 11 +- .../datasets/file/base.py | 6 + .../datasets/table.py | 138 +++- src/universal_transfer_operator/settings.py | 19 + test-connections.yaml | 142 ++++ tests/conftest.py | 55 ++ .../test_data_provider/test_data_provider.py | 3 + .../test_database/test_base.py | 73 +++ .../test_database/test_sqlite.py | 164 +++++ .../test_filesystem/test_sftp.py | 6 +- tests/utils/test_utils.py | 13 + 22 files changed, 1472 insertions(+), 43 deletions(-) create mode 100644 src/universal_transfer_operator/data_providers/database/__init__.py create mode 100644 src/universal_transfer_operator/data_providers/database/base.py create mode 100644 src/universal_transfer_operator/data_providers/database/sqlite.py create mode 100644 test-connections.yaml create mode 100644 tests/test_data_provider/test_database/test_base.py create mode 100644 tests/test_data_provider/test_database/test_sqlite.py diff --git a/example_dags/example_universal_transfer_operator.py b/example_dags/example_universal_transfer_operator.py index 5382c82..dae2622 100644 --- a/example_dags/example_universal_transfer_operator.py +++ b/example_dags/example_universal_transfer_operator.py @@ -3,12 +3,15 @@ from airflow import DAG -from universal_transfer_operator.constants import TransferMode +from universal_transfer_operator.constants import FileType, TransferMode from universal_transfer_operator.datasets.file.base import File from universal_transfer_operator.datasets.table import Metadata, Table from universal_transfer_operator.integrations.fivetran import Connector, Destination, FiveTranOptions, Group from universal_transfer_operator.universal_transfer_operator import UniversalTransferOperator +s3_bucket = os.getenv("S3_BUCKET", "s3://astro-sdk-test") +gcs_bucket = os.getenv("GCS_BUCKET", "gs://uto-test") + with DAG( "example_universal_transfer_operator", 
schedule_interval=None, @@ -17,22 +20,36 @@ ) as dag: transfer_non_native_gs_to_s3 = UniversalTransferOperator( task_id="transfer_non_native_gs_to_s3", - source_dataset=File(path="gs://uto-test/uto/", conn_id="google_cloud_default"), - destination_dataset=File(path="s3://astro-sdk-test/uto/", conn_id="aws_default"), + source_dataset=File(path=f"{gcs_bucket}/uto/", conn_id="google_cloud_default"), + destination_dataset=File(path=f"{s3_bucket}/uto/", conn_id="aws_default"), ) transfer_non_native_s3_to_gs = UniversalTransferOperator( task_id="transfer_non_native_s3_to_gs", - source_dataset=File(path="s3://astro-sdk-test/uto/", conn_id="aws_default"), + source_dataset=File(path=f"{s3_bucket}/uto/", conn_id="aws_default"), destination_dataset=File( - path="gs://uto-test/uto/", + path=f"{gcs_bucket}/uto/", conn_id="google_cloud_default", ), ) + transfer_non_native_s3_to_sqlite = UniversalTransferOperator( + task_id="transfer_non_native_s3_to_sqlite", + source_dataset=File(path=f"{s3_bucket}/uto/csv_files/", conn_id="aws_default", filetype=FileType.CSV), + destination_dataset=Table(name="uto_s3_table", conn_id="sqlite_default"), + ) + + transfer_non_native_gs_to_sqlite = UniversalTransferOperator( + task_id="transfer_non_native_gs_to_sqlite", + source_dataset=File( + path=f"{gcs_bucket}/uto/csv_files/", conn_id="google_cloud_default", filetype=FileType.CSV + ), + destination_dataset=Table(name="uto_gs_table", conn_id="sqlite_default"), + ) + transfer_fivetran_with_connector_id = UniversalTransferOperator( task_id="transfer_fivetran_with_connector_id", - source_dataset=File(path="s3://astro-sdk-test/uto/", conn_id="aws_default"), + source_dataset=File(path=f"{s3_bucket}/uto/", conn_id="aws_default"), destination_dataset=Table(name="fivetran_test", conn_id="snowflake_default"), transfer_mode=TransferMode.THIRDPARTY, transfer_params=FiveTranOptions(conn_id="fivetran_default", connector_id="filing_muppet"), @@ -40,7 +57,7 @@ transfer_fivetran_without_connector_id = UniversalTransferOperator( task_id="transfer_fivetran_without_connector_id", - source_dataset=File(path="s3://astro-sdk-test/uto/", conn_id="aws_default"), + source_dataset=File(path=f"{s3_bucket}/uto/", conn_id="aws_default"), destination_dataset=Table( name="fivetran_test", conn_id="snowflake_default", diff --git a/pyproject.toml b/pyproject.toml index dcc7c73..32be062 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,7 +61,10 @@ google = [ snowflake = [ "apache-airflow-providers-snowflake", "snowflake-sqlalchemy>=1.2.0", - "snowflake-connector-python[pandas]", + "snowflake-connector-python[pandas]<3.0.0", + # pinning snowflake-connector-python[pandas]<3.0.0 due to a conflict in snowflake-connector-python/pyarrow/google + # packages and pandas-gbq/google packages which is forcing pandas-gbq of version 0.13.2 installed, which is not + # compatible with pandas 1.5.3 ] amazon = [ @@ -79,7 +82,10 @@ all = [ "apache-airflow-providers-google>=6.4.0", "apache-airflow-providers-snowflake", "smart-open[all]>=5.2.1", - "snowflake-connector-python[pandas]", + "snowflake-connector-python[pandas]<3.0.0", + # pinning snowflake-connector-python[pandas]<3.0.0 due to a conflict in snowflake-connector-python/pyarrow/google + # packages and pandas-gbq/google packages which is forcing pandas-gbq of version 0.13.2 installed, which is not + # compatible with pandas 1.5.3 "snowflake-sqlalchemy>=1.2.0", "sqlalchemy-bigquery>=1.3.0", "s3fs", diff --git a/src/universal_transfer_operator/constants.py b/src/universal_transfer_operator/constants.py index 
b8fb5a9..647d145 100644 --- a/src/universal_transfer_operator/constants.py +++ b/src/universal_transfer_operator/constants.py @@ -99,3 +99,4 @@ def __repr__(self): LoadExistStrategy = Literal["replace", "append"] DEFAULT_CHUNK_SIZE = 1000000 ColumnCapitalization = Literal["upper", "lower", "original"] +DEFAULT_SCHEMA = "tmp_transfers" diff --git a/src/universal_transfer_operator/data_providers/__init__.py b/src/universal_transfer_operator/data_providers/__init__.py index 578ce77..d7dff96 100644 --- a/src/universal_transfer_operator/data_providers/__init__.py +++ b/src/universal_transfer_operator/data_providers/__init__.py @@ -5,14 +5,18 @@ from universal_transfer_operator.constants import TransferMode from universal_transfer_operator.data_providers.base import DataProviders from universal_transfer_operator.datasets.base import Dataset +from universal_transfer_operator.datasets.file.base import File +from universal_transfer_operator.datasets.table import Table from universal_transfer_operator.utils import TransferParameters, get_class_name DATASET_CONN_ID_TO_DATAPROVIDER_MAPPING = { - "s3": "universal_transfer_operator.data_providers.filesystem.aws.s3", - "aws": "universal_transfer_operator.data_providers.filesystem.aws.s3", - "gs": "universal_transfer_operator.data_providers.filesystem.google.cloud.gcs", - "google_cloud_platform": "universal_transfer_operator.data_providers.filesystem.google.cloud.gcs", - "sftp": "universal_transfer_operator.data_providers.filesystem.sftp", + ("s3", File): "universal_transfer_operator.data_providers.filesystem.aws.s3", + ("aws", File): "universal_transfer_operator.data_providers.filesystem.aws.s3", + ("gs", File): "universal_transfer_operator.data_providers.filesystem.google.cloud.gcs", + ("google_cloud_platform", File): "universal_transfer_operator.data_providers.filesystem.google.cloud.gcs", + ("sqlite", Table): "universal_transfer_operator.data_providers.database.sqlite", + ("sftp", File): "universal_transfer_operator.data_providers.filesystem.sftp", + ("sqlite", Table): "universal_transfer_operator.data_providers.database.sqlite", } @@ -22,7 +26,7 @@ def create_dataprovider( transfer_mode: TransferMode = TransferMode.NONNATIVE, ) -> DataProviders: conn_type = BaseHook.get_connection(dataset.conn_id).conn_type - module_path = DATASET_CONN_ID_TO_DATAPROVIDER_MAPPING[conn_type] + module_path = DATASET_CONN_ID_TO_DATAPROVIDER_MAPPING[(conn_type, type(dataset))] module = importlib.import_module(module_path) class_name = get_class_name(module_ref=module, suffix="DataProvider") data_provider: DataProviders = getattr(module, class_name)( diff --git a/src/universal_transfer_operator/data_providers/base.py b/src/universal_transfer_operator/data_providers/base.py index 8b50053..c7e95d1 100644 --- a/src/universal_transfer_operator/data_providers/base.py +++ b/src/universal_transfer_operator/data_providers/base.py @@ -97,3 +97,11 @@ def openlineage_dataset_name(self) -> str: https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md """ raise NotImplementedError + + @property + def openlineage_dataset_uri(self) -> str: + """ + Returns the open lineage dataset uri as per + https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md + """ + raise NotImplementedError diff --git a/src/universal_transfer_operator/data_providers/database/__init__.py b/src/universal_transfer_operator/data_providers/database/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/universal_transfer_operator/data_providers/database/base.py 
b/src/universal_transfer_operator/data_providers/database/base.py new file mode 100644 index 0000000..b516884 --- /dev/null +++ b/src/universal_transfer_operator/data_providers/database/base.py @@ -0,0 +1,615 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import pandas as pd +import sqlalchemy + +if TYPE_CHECKING: # pragma: no cover + from sqlalchemy.engine.cursor import CursorResult + +import warnings + +import attr +from airflow.hooks.dbapi import DbApiHook +from pandas.io.sql import SQLDatabase +from sqlalchemy.sql import ClauseElement + +from universal_transfer_operator.constants import ( + DEFAULT_CHUNK_SIZE, + ColumnCapitalization, + LoadExistStrategy, + Location, + TransferMode, +) +from universal_transfer_operator.data_providers.base import DataProviders +from universal_transfer_operator.data_providers.filesystem import resolve_file_path_pattern +from universal_transfer_operator.data_providers.filesystem.base import FileStream +from universal_transfer_operator.datasets.base import Dataset +from universal_transfer_operator.datasets.file.base import File +from universal_transfer_operator.datasets.table import Metadata, Table +from universal_transfer_operator.settings import LOAD_TABLE_AUTODETECT_ROWS_COUNT, SCHEMA +from universal_transfer_operator.universal_transfer_operator import TransferParameters +from universal_transfer_operator.utils import get_dataset_connection_type + + +class DatabaseDataProvider(DataProviders): + """DatabaseProviders represent all the DataProviders interactions with Databases.""" + + _create_schema_statement: str = "CREATE SCHEMA IF NOT EXISTS {}" + _drop_table_statement: str = "DROP TABLE IF EXISTS {}" + _create_table_statement: str = "CREATE TABLE IF NOT EXISTS {} AS {}" + # Used to normalize the ndjson when appending fields in nested fields. + # Example - + # ndjson - {'a': {'b': 'val'}} + # the col names generated is 'a.b'. char '.' maybe an illegal char in some db's col name. 
+ # Contains the illegal char and there replacement, where the value in + # illegal_column_name_chars[0] will be replaced by value in illegal_column_name_chars_replacement[0] + illegal_column_name_chars: list[str] = [] + illegal_column_name_chars_replacement: list[str] = [] + # In run_raw_sql operator decides if we want to return results directly or process them by handler provided + IGNORE_HANDLER_IN_RUN_RAW_SQL: bool = False + NATIVE_PATHS: dict[Any, Any] = {} + DEFAULT_SCHEMA = SCHEMA + + def __init__( + self, + dataset: Table, + transfer_mode: TransferMode, + transfer_params: TransferParameters = attr.field( + factory=TransferParameters, + converter=lambda val: TransferParameters(**val) if isinstance(val, dict) else val, + ), + ): + self.dataset = dataset + self.transfer_params = transfer_params + self.transfer_mode = transfer_mode + self.transfer_mapping = {} + self.LOAD_DATA_NATIVELY_FROM_SOURCE: dict = {} + super().__init__( + dataset=self.dataset, transfer_mode=self.transfer_mode, transfer_params=self.transfer_params + ) + + def __repr__(self): + return f'{self.__class__.__name__}(conn_id="{self.dataset.conn_id})' + + @property + def sql_type(self): + raise NotImplementedError + + @property + def hook(self) -> DbApiHook: + """Return an instance of the database-specific Airflow hook.""" + raise NotImplementedError + + @property + def connection(self) -> sqlalchemy.engine.base.Connection: + """Return a Sqlalchemy connection object for the given database.""" + return self.sqlalchemy_engine.connect() + + @property + def sqlalchemy_engine(self) -> sqlalchemy.engine.base.Engine: + """Return Sqlalchemy engine.""" + return self.hook.get_sqlalchemy_engine() # type: ignore[no-any-return] + + @property + def transport_params(self) -> dict | None: # skipcq: PYL-R0201 + """Get credentials required by smart open to access files""" + return None + + def run_sql( + self, + sql: str | ClauseElement = "", + parameters: dict | None = None, + **kwargs, + ) -> CursorResult: + """ + Return the results to running a SQL statement. + + Whenever possible, this method should be implemented using Airflow Hooks, + since this will simplify the integration with Async operators. + + :param sql: Contains SQL query to be run against database + :param parameters: Optional parameters to be used to render the query + :param autocommit: Optional autocommit flag + """ + if parameters is None: + parameters = {} + + if "sql_statement" in kwargs: # pragma: no cover + warnings.warn( + "`sql_statement` is deprecated and will be removed in future release" + "Please use `sql` param instead.", + DeprecationWarning, + stacklevel=2, + ) + sql = kwargs.get("sql_statement") # type: ignore + + # We need to autocommit=True to make sure the query runs. This is done exclusively for SnowflakeDatabase's + # truncate method to reflect changes. + if isinstance(sql, str): + result = self.connection.execute( + sqlalchemy.text(sql).execution_options(autocommit=True), parameters + ) + else: + result = self.connection.execute(sql, parameters) + return result + + def columns_exist(self, table: Table, columns: list[str]) -> bool: + """ + Check that a list of columns exist in the given table. + + :param table: The table to check in. + :param columns: The columns to check. + + :returns: whether the columns exist in the table or not. 
+ """ + sqla_table = self.get_sqla_table(table) + return all( + any(sqla_column.name == column for sqla_column in sqla_table.columns) for column in columns + ) + + def table_exists(self, table: Table) -> bool: + """ + Check if a table exists in the database. + + :param table: Details of the table we want to check that exists + """ + table_qualified_name = self.get_table_qualified_name(table) + inspector = sqlalchemy.inspect(self.sqlalchemy_engine) + return bool(inspector.dialect.has_table(self.connection, table_qualified_name)) + + def check_if_transfer_supported(self, source_dataset: Table) -> bool: + """ + Checks if the transfer is supported from source to destination based on source_dataset. + """ + source_connection_type = get_dataset_connection_type(source_dataset) + return Location(source_connection_type) in self.transfer_mapping + + def read(self): + """ ""Read the dataset and write to local reference location""" + raise NotImplementedError + + def write(self, source_ref: FileStream): + """Write the data from local reference location to the dataset""" + return self.load_file_to_table( + input_file=source_ref.actual_file, + output_table=self.dataset, + ) + + def load_data_from_source_natively(self, source_dataset: Table, destination_dataset: Dataset) -> None: + """ + Loads data from source dataset to the destination using data provider + """ + if not self.check_if_transfer_supported(source_dataset=source_dataset): + raise ValueError("Transfer not supported yet.") + + source_connection_type = get_dataset_connection_type(source_dataset) + method_name = self.LOAD_DATA_NATIVELY_FROM_SOURCE.get(source_connection_type) + if method_name: + transfer_method = self.__getattribute__(method_name) + return transfer_method( + source_dataset=source_dataset, + destination_dataset=destination_dataset, + ) + else: + raise ValueError(f"No transfer performed from {source_connection_type} to S3.") + + @property + def openlineage_dataset_namespace(self) -> str: + """ + Returns the open lineage dataset namespace as per + https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md + """ + raise NotImplementedError + + @property + def openlineage_dataset_name(self) -> str: + """ + Returns the open lineage dataset name as per + https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md + """ + raise NotImplementedError + + @property + def openlineage_dataset_uri(self) -> str: + """ + Returns the open lineage dataset uri as per + https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md + """ + raise NotImplementedError + + # --------------------------------------------------------- + # Table metadata + # --------------------------------------------------------- + @staticmethod + def get_table_qualified_name(table: Table) -> str: # skipcq: PYL-R0201 + """ + Return table qualified name. This is Database-specific. + For instance, in Sqlite this is the table name. In Snowflake, however, it is the database, schema and table + + :param table: The table we want to retrieve the qualified name for. + """ + # Initially this method belonged to the Table class. + # However, in order to have an agnostic table class implementation, + # we are keeping all methods which vary depending on the database within the Database class. 
+ if table.metadata and table.metadata.schema: + qualified_name = f"{table.metadata.schema}.{table.name}" + else: + qualified_name = table.name + return qualified_name + + @property + def default_metadata(self) -> Metadata: + """ + Extract the metadata available within the Airflow connection associated with + self.dataset.conn_id. + + :return: a Metadata instance + """ + raise NotImplementedError + + def populate_table_metadata(self, table: Table) -> Table: + """ + Given a table, check if the table has metadata. + If the metadata is missing, and the database has metadata, assign it to the table. + If the table schema was not defined by the end, retrieve the user-defined schema. + This method performs the changes in-place and also returns the table. + + :param table: Table to potentially have their metadata changed + :return table: Return the modified table + """ + if table.metadata and table.metadata.is_empty() and self.default_metadata: + table.metadata = self.default_metadata + if not table.metadata.schema: + table.metadata.schema = self.DEFAULT_SCHEMA + return table + + # --------------------------------------------------------- + # Table creation & deletion methods + # --------------------------------------------------------- + def create_table_using_columns(self, table: Table) -> None: + """ + Create a SQL table using the table columns. + + :param table: The table to be created. + """ + if not table.columns: + raise ValueError("To use this method, table.columns must be defined") + + metadata = table.sqlalchemy_metadata + sqlalchemy_table = sqlalchemy.Table(table.name, metadata, *table.columns) + metadata.create_all(self.sqlalchemy_engine, tables=[sqlalchemy_table]) + + def is_native_autodetect_schema_available( # skipcq: PYL-R0201 + self, file: File # skipcq: PYL-W0613 + ) -> bool: + """ + Check if native auto detection of schema is available. + + :param file: File used to check the file type of to decide + whether there is a native auto detection available for it. + """ + return False + + def create_table_using_native_schema_autodetection( + self, + table: Table, + file: File, + ) -> None: + """ + Create a SQL table, automatically inferring the schema using the given file via native database support. + + :param table: The table to be created. + :param file: File used to infer the new table columns. + """ + raise NotImplementedError("Missing implementation of native schema autodetection.") + + def create_table_using_schema_autodetection( + self, + table: Table, + file: File | None = None, + dataframe: pd.DataFrame | None = None, + columns_names_capitalization: ColumnCapitalization = "original", # skipcq + ) -> None: + """ + Create a SQL table, automatically inferring the schema using the given file. + + :param table: The table to be created. + :param file: File used to infer the new table columns. 
+ :param dataframe: Dataframe used to infer the new table columns if there is no file + :param columns_names_capitalization: determines whether to convert all columns to lowercase/uppercase + in the resulting dataframe + """ + if file is None: + if dataframe is None: + raise ValueError( + "File or Dataframe is required for creating table using schema autodetection" + ) + source_dataframe = dataframe + else: + source_dataframe = file.export_to_dataframe(nrows=LOAD_TABLE_AUTODETECT_ROWS_COUNT) + + db = SQLDatabase(engine=self.sqlalchemy_engine) + db.prep_table( + source_dataframe, + table.name.lower(), + schema=table.metadata.schema, + if_exists="replace", + index=False, + ) + + def create_table( + self, + table: Table, + file: File | None = None, + dataframe: pd.DataFrame | None = None, + columns_names_capitalization: ColumnCapitalization = "original", + use_native_support: bool = True, + ) -> None: + """ + Create a table either using its explicitly defined columns or inferring + it's columns from a given file. + + :param table: The table to be created + :param file: (optional) File used to infer the table columns. + :param dataframe: (optional) Dataframe used to infer the new table columns if there is no file + :param columns_names_capitalization: determines whether to convert all columns to lowercase/uppercase + in the resulting dataframe + """ + if table.columns: + self.create_table_using_columns(table) + elif use_native_support and file and self.is_native_autodetect_schema_available(file): + self.create_table_using_native_schema_autodetection(table, file) + else: + self.create_table_using_schema_autodetection( + table, + file=file, + dataframe=dataframe, + columns_names_capitalization=columns_names_capitalization, + ) + + def create_table_from_select_statement( + self, + statement: str, + target_table: Table, + parameters: dict | None = None, + ) -> None: + """ + Export the result rows of a query statement into another table. + + :param statement: SQL query statement + :param target_table: Destination table where results will be recorded. + :param parameters: (Optional) parameters to be used to render the SQL query + """ + statement = self._create_table_statement.format( + self.get_table_qualified_name(target_table), statement + ) + self.run_sql(statement, parameters) + + def drop_table(self, table: Table) -> None: + """ + Delete a SQL table, if it exists. + + :param table: The table to be deleted. 
+ """ + statement = self._drop_table_statement.format(self.get_table_qualified_name(table)) + self.run_sql(statement) + + # --------------------------------------------------------- + # Table load methods + # --------------------------------------------------------- + + def create_schema_and_table_if_needed( + self, + table: Table, + file: File, + normalize_config: dict | None = None, + columns_names_capitalization: ColumnCapitalization = "original", + if_exists: LoadExistStrategy = "replace", + use_native_support: bool = True, + ): + """ + Checks if the autodetect schema exists for native support else creates the schema and table + :param table: Table to create + :param file: File path and conn_id for object stores + :param normalize_config: pandas json_normalize params config + :param columns_names_capitalization: determines whether to convert all columns to lowercase/uppercase + :param if_exists: Overwrite file if exists + :param use_native_support: Use native support for data transfer if available on the destination + """ + is_schema_autodetection_supported = self.check_schema_autodetection_is_supported(source_file=file) + is_file_pattern_based_schema_autodetection_supported = ( + self.check_file_pattern_based_schema_autodetection_is_supported(source_file=file) + ) + if if_exists == "replace": + self.drop_table(table) + if use_native_support and is_schema_autodetection_supported and not file.is_pattern(): + return + if ( + use_native_support + and file.is_pattern() + and is_schema_autodetection_supported + and is_file_pattern_based_schema_autodetection_supported + ): + return + + self.create_schema_if_needed(table.metadata.schema) + if if_exists == "replace" or not self.table_exists(table): + files = resolve_file_path_pattern( + file, + normalize_config=normalize_config, + filetype=file.type.name, + transfer_params=self.transfer_params, + transfer_mode=self.transfer_mode, + ) + self.create_table( + table, + # We only use the first file for inferring the table schema + files[0], + columns_names_capitalization=columns_names_capitalization, + use_native_support=use_native_support, + ) + + def fetch_all_rows(self, table: Table, row_limit: int = -1) -> list: + """ + Fetches all rows for a table and returns as a list. This is needed because some + databases have different cursors that require different methods to fetch rows + + :param row_limit: Limit the number of rows returned, by default return all rows. + :param table: The table metadata needed to fetch the rows + :return: a list of rows + """ + statement = f"SELECT * FROM {self.get_table_qualified_name(table)}" # skipcq: BAN-B608 + if row_limit > -1: + statement = statement + f" LIMIT {row_limit}" # skipcq: BAN-B608 + response = self.run_sql(statement) + return response.fetchall() # type: ignore + + def load_file_to_table( + self, + input_file: File, + output_table: Table, + normalize_config: dict | None = None, + if_exists: LoadExistStrategy = "replace", + chunk_size: int = DEFAULT_CHUNK_SIZE, + columns_names_capitalization: ColumnCapitalization = "original", + **kwargs, + ): + """ + Load content of multiple files in output_table. + Multiple files are sourced from the file path, which can also be path pattern. 
+ + :param input_file: File path and conn_id for object stores + :param output_table: Table to create + :param if_exists: Overwrite file if exists + :param chunk_size: Specify the number of records in each batch to be written at a time + :param use_native_support: Use native support for data transfer if available on the destination + :param normalize_config: pandas json_normalize params config + :param native_support_kwargs: kwargs to be used by method involved in native support flow + :param columns_names_capitalization: determines whether to convert all columns to lowercase/uppercase + in the resulting dataframe + :param enable_native_fallback: Use enable_native_fallback=True to fall back to default transfer + """ + normalize_config = normalize_config or {} + + self.create_schema_and_table_if_needed( + file=input_file, + table=output_table, + columns_names_capitalization=columns_names_capitalization, + if_exists=if_exists, + normalize_config=normalize_config, + ) + self.load_file_to_table_using_pandas( + input_file=input_file, + output_table=output_table, + normalize_config=normalize_config, + if_exists="append", + chunk_size=chunk_size, + ) + + def load_file_to_table_using_pandas( + self, + input_file: File, + output_table: Table, + normalize_config: dict | None = None, + if_exists: LoadExistStrategy = "replace", + chunk_size: int = DEFAULT_CHUNK_SIZE, + ): + input_files = resolve_file_path_pattern( + file=input_file, + normalize_config=normalize_config, + filetype=input_file.type.name, + transfer_params=self.transfer_params, + transfer_mode=self.transfer_mode, + ) + + for file in input_files: + self.load_pandas_dataframe_to_table( + self.get_dataframe_from_file(file), + output_table, + chunk_size=chunk_size, + if_exists=if_exists, + ) + + def load_pandas_dataframe_to_table( + self, + source_dataframe: pd.DataFrame, + target_table: Table, + if_exists: LoadExistStrategy = "replace", + chunk_size: int = DEFAULT_CHUNK_SIZE, + ) -> None: + """ + Create a table with the dataframe's contents. + If the table already exists, append or replace the content, depending on the value of `if_exists`. + + :param source_dataframe: Local or remote filepath + :param target_table: Table in which the file will be loaded + :param if_exists: Strategy to be used in case the target table already exists. + :param chunk_size: Specify the number of rows in each batch to be written at a time. + """ + self._assert_not_empty_df(source_dataframe) + + source_dataframe.to_sql( + self.get_table_qualified_name(target_table), + con=self.sqlalchemy_engine, + if_exists=if_exists, + chunksize=chunk_size, + method="multi", + index=False, + ) + + @staticmethod + def _assert_not_empty_df(df): + """Raise error if dataframe empty + + param df: A dataframe + """ + if df.empty: + raise ValueError("Can't load empty dataframe") + + @staticmethod + def get_dataframe_from_file(file: File): + """ + Get pandas dataframe file. We need export_to_dataframe() for Biqqery,Snowflake and Redshift except for Postgres. + For postgres we are overriding this method and using export_to_dataframe_via_byte_stream(). + export_to_dataframe_via_byte_stream copies files in a buffer and then use that buffer to ingest data. + With this approach we have significant performance boost for postgres. 
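In practice each resolved file ends up in `load_pandas_dataframe_to_table`, which is plain chunked `pandas.DataFrame.to_sql()`. A minimal sketch of that final step against SQLite, assuming the `sqlite_default` connection used by the integration tests:

```python
import pandas as pd

from universal_transfer_operator.constants import TransferMode
from universal_transfer_operator.data_providers.database.sqlite import SqliteDataProvider
from universal_transfer_operator.datasets.table import Table

dp = SqliteDataProvider(
    dataset=Table(name="uto_example", conn_id="sqlite_default"),
    transfer_mode=TransferMode.NONNATIVE,
)
df = pd.DataFrame(data={"id": [1, 2], "name": ["a", "b"]})

# Writes the dataframe in chunks via to_sql(); "replace" recreates the table first.
dp.load_pandas_dataframe_to_table(df, dp.dataset, if_exists="replace")
assert dp.row_count(dp.dataset) == 2
```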
+ + :param file: File path and conn_id for object stores + """ + + return file.export_to_dataframe() + + def check_schema_autodetection_is_supported( # skipcq: PYL-R0201 + self, source_file: File # skipcq: PYL-W0613 + ) -> bool: + """ + Checks if schema autodetection is handled natively by the database. Return False by default. + + :param source_file: File from which we need to transfer data + """ + return False + + def check_file_pattern_based_schema_autodetection_is_supported( # skipcq: PYL-R0201 + self, source_file: File # skipcq: PYL-W0613 + ) -> bool: + """ + Checks if schema autodetection is handled natively by the database for file + patterns and prefixes. Return False by default. + + :param source_file: File from which we need to transfer data + """ + return False + + def row_count(self, table: Table): + """ + Returns the number of rows in a table. + + :param table: table to count + :return: The number of rows in the table + """ + result = self.run_sql( + f"select count(*) from {self.get_table_qualified_name(table)}" # skipcq: BAN-B608 + ).scalar() + return result diff --git a/src/universal_transfer_operator/data_providers/database/sqlite.py b/src/universal_transfer_operator/data_providers/database/sqlite.py new file mode 100644 index 0000000..4470b49 --- /dev/null +++ b/src/universal_transfer_operator/data_providers/database/sqlite.py @@ -0,0 +1,139 @@ +from __future__ import annotations + +import socket + +import attr +from airflow.providers.sqlite.hooks.sqlite import SqliteHook +from sqlalchemy import MetaData as SqlaMetaData, create_engine +from sqlalchemy.engine.base import Engine +from sqlalchemy.sql.schema import Table as SqlaTable + +from universal_transfer_operator.data_providers.database.base import DatabaseDataProvider, FileStream +from universal_transfer_operator.datasets.table import Metadata, Table +from universal_transfer_operator.universal_transfer_operator import TransferParameters + + +class SqliteDataProvider(DatabaseDataProvider): + """SqliteDataProvider represent all the DataProviders interactions with Sqlite Databases.""" + + def __init__( + self, + dataset: Table, + transfer_mode, + transfer_params: TransferParameters = attr.field( + factory=TransferParameters, + converter=lambda val: TransferParameters(**val) if isinstance(val, dict) else val, + ), + ): + self.dataset = dataset + self.transfer_params = transfer_params + self.transfer_mode = transfer_mode + self.transfer_mapping = {} + self.LOAD_DATA_NATIVELY_FROM_SOURCE: dict = {} + super().__init__( + dataset=self.dataset, transfer_mode=self.transfer_mode, transfer_params=self.transfer_params + ) + + def __repr__(self): + return f'{self.__class__.__name__}(conn_id="{self.dataset.conn_id})' + + @property + def sql_type(self) -> str: + return "sqlite" + + @property + def hook(self) -> SqliteHook: + """Retrieve Airflow hook to interface with the Sqlite database.""" + return SqliteHook(sqlite_conn_id=self.dataset.conn_id) + + @property + def sqlalchemy_engine(self) -> Engine: + """Return SQAlchemy engine.""" + # Airflow uses sqlite3 library and not SqlAlchemy for SqliteHook + # and it only uses the hostname directly. 
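Concretely, with an Airflow connection whose host points at the database file (for example the `sqlite_conn` entry in `test-connections.yaml`, host `/tmp/sqlite.db`), the provider ends up with a plain file-backed engine. A small sketch, assuming that connection exists:

```python
from universal_transfer_operator.constants import TransferMode
from universal_transfer_operator.data_providers.database.sqlite import SqliteDataProvider
from universal_transfer_operator.datasets.table import Table

dp = SqliteDataProvider(
    dataset=Table("some_table", conn_id="sqlite_conn"),
    transfer_mode=TransferMode.NONNATIVE,
)

# The engine is built from the connection host, not from a SQLAlchemy URI stored in the connection.
assert str(dp.sqlalchemy_engine.url) == "sqlite:////tmp/sqlite.db"
```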
+ airflow_conn = self.hook.get_connection(self.dataset.conn_id) + return create_engine(f"sqlite:///{airflow_conn.host}") + + @property + def default_metadata(self) -> Metadata: + """Since Sqlite does not use Metadata, we return an empty Metadata instances.""" + return Metadata() + + def read(self): + """ ""Read the dataset and write to local reference location""" + raise NotImplementedError + + def write(self, source_ref: FileStream): + """Write the data from local reference location to the dataset""" + return self.load_file_to_table( + input_file=source_ref.actual_file, + output_table=self.dataset, + ) + + # --------------------------------------------------------- + # Table metadata + # --------------------------------------------------------- + @staticmethod + def get_table_qualified_name(table: Table) -> str: + """ + Return the table qualified name. + + :param table: The table we want to retrieve the qualified name for. + """ + return str(table.name) + + def populate_table_metadata(self, table: Table) -> Table: + """ + Since SQLite does not have a concept of databases or schemas, we just return the table as is, + without any modifications. + """ + table.conn_id = table.conn_id or self.dataset.conn_id + return table + + def create_schema_if_needed(self, schema: str | None) -> None: + """ + Since SQLite does not have schemas, we do not need to set a schema here. + """ + + def schema_exists(self, schema: str) -> bool: # skipcq PYL-W0613,PYL-R0201 + """ + Check if a schema exists. We return false for sqlite since sqlite does not have schemas + """ + return False + + def get_sqla_table(self, table: Table) -> SqlaTable: + """ + Return SQLAlchemy table instance + + :param table: Astro Table to be converted to SQLAlchemy table instance + """ + return SqlaTable(table.name, SqlaMetaData(), autoload_with=self.sqlalchemy_engine) + + @property + def openlineage_dataset_name(self) -> str: + """ + Returns the open lineage dataset name as per + https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md + Example: /tmp/local.db.table_name + """ + conn = self.hook.get_connection(self.dataset.conn_id) + return f"{conn.host}.{self.dataset.name}" + + @property + def openlineage_dataset_namespace(self) -> str: + """ + Returns the open lineage dataset namespace as per + https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md + Example: file://127.0.0.1:22 + """ + conn = self.hook.get_connection(self.dataset.conn_id) + port = conn.port or 22 + return f"file://{socket.gethostbyname(socket.gethostname())}:{port}" + + @property + def openlineage_dataset_uri(self) -> str: + """ + Returns the open lineage dataset uri as per + https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md + """ + return f"{self.openlineage_dataset_namespace()}{self.openlineage_dataset_name()}" diff --git a/src/universal_transfer_operator/data_providers/filesystem/__init__.py b/src/universal_transfer_operator/data_providers/filesystem/__init__.py index e69de29..8446055 100644 --- a/src/universal_transfer_operator/data_providers/filesystem/__init__.py +++ b/src/universal_transfer_operator/data_providers/filesystem/__init__.py @@ -0,0 +1,44 @@ +from __future__ import annotations + +from universal_transfer_operator.constants import FileType, TransferMode +from universal_transfer_operator.data_providers import create_dataprovider +from universal_transfer_operator.datasets.file.base import File +from universal_transfer_operator.utils import TransferParameters + + +def resolve_file_path_pattern( + file: File, + 
filetype: FileType | None = None, + normalize_config: dict | None = None, + transfer_params: TransferParameters = None, + transfer_mode: TransferMode = TransferMode.NONNATIVE, +) -> list[File]: + """get file objects by resolving path_pattern from local/object stores + path_pattern can be + 1. local location - glob pattern + 2. s3/gcs location - prefix + + :param file: File dataset object + :param filetype: constant to provide an explicit file type + :param normalize_config: parameters in dict format of pandas json_normalize() function + :param transfer_params: kwargs to be used by method involved in transfer flow. + :param transfer_mode: Use transfer_mode TransferMode; native, non-native or thirdparty. + """ + location = create_dataprovider( + dataset=file, + transfer_params=transfer_params, + transfer_mode=transfer_mode, + ) + files = [] + for path in location.paths: + if not path.endswith("/"): + file = File( + path=path, + conn_id=file.conn_id, + filetype=filetype, + normalize_config=normalize_config, + ) + files.append(file) + if len(files) == 0: + raise FileNotFoundError(f"File(s) not found for path/pattern '{file.path}'") + return files diff --git a/src/universal_transfer_operator/data_providers/filesystem/aws/s3.py b/src/universal_transfer_operator/data_providers/filesystem/aws/s3.py index 75f149b..46f4a6b 100644 --- a/src/universal_transfer_operator/data_providers/filesystem/aws/s3.py +++ b/src/universal_transfer_operator/data_providers/filesystem/aws/s3.py @@ -174,3 +174,11 @@ def openlineage_dataset_name(self) -> str: https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md """ raise NotImplementedError + + @property + def openlineage_dataset_uri(self) -> str: + """ + Returns the open lineage dataset uri as per + https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md + """ + raise NotImplementedError diff --git a/src/universal_transfer_operator/data_providers/filesystem/base.py b/src/universal_transfer_operator/data_providers/filesystem/base.py index fed5459..8309a48 100644 --- a/src/universal_transfer_operator/data_providers/filesystem/base.py +++ b/src/universal_transfer_operator/data_providers/filesystem/base.py @@ -28,6 +28,7 @@ class TempFile: class FileStream: remote_obj_buffer: io.IOBase actual_filename: Path + actual_file: File class BaseFilesystemProviders(DataProviders): @@ -90,7 +91,9 @@ def read_using_smart_open(self): files = self.paths for file in files: yield FileStream( - remote_obj_buffer=self._convert_remote_file_to_byte_stream(file), actual_filename=file + remote_obj_buffer=self._convert_remote_file_to_byte_stream(file), + actual_filename=file, + actual_file=self.dataset, ) def _convert_remote_file_to_byte_stream(self, file: str) -> io.IOBase: @@ -184,3 +187,11 @@ def openlineage_dataset_name(self) -> str: https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md """ raise NotImplementedError + + @property + def openlineage_dataset_uri(self) -> str: + """ + Returns the open lineage dataset uri as per + https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md + """ + raise NotImplementedError diff --git a/src/universal_transfer_operator/data_providers/filesystem/google/cloud/gcs.py b/src/universal_transfer_operator/data_providers/filesystem/google/cloud/gcs.py index 781ff68..709baa5 100644 --- a/src/universal_transfer_operator/data_providers/filesystem/google/cloud/gcs.py +++ b/src/universal_transfer_operator/data_providers/filesystem/google/cloud/gcs.py @@ -59,9 +59,10 @@ def transport_params(self) -> dict: def 
paths(self) -> list[str]: """Resolve GS file paths with prefix""" url = urlparse(self.dataset.path) + prefix = url.path[1:] prefixes = self.hook.list( bucket_name=self.bucket_name, # type: ignore - prefix=self.prefix, + prefix=prefix, delimiter=self.delimiter, ) paths = [urlunparse((url.scheme, url.netloc, keys, "", "", "")) for keys in prefixes] @@ -179,3 +180,11 @@ def openlineage_dataset_name(self) -> str: https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md """ raise NotImplementedError + + @property + def openlineage_dataset_uri(self) -> str: + """ + Returns the open lineage dataset uri as per + https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md + """ + raise NotImplementedError diff --git a/src/universal_transfer_operator/datasets/file/base.py b/src/universal_transfer_operator/datasets/file/base.py index 2e8ee0b..a3e850e 100644 --- a/src/universal_transfer_operator/datasets/file/base.py +++ b/src/universal_transfer_operator/datasets/file/base.py @@ -32,6 +32,12 @@ class File(Dataset): uri: str = field(init=False) extra: dict = field(init=True, factory=dict) + @property + def location(self): + from universal_transfer_operator.data_providers import create_dataprovider + + return create_dataprovider(dataset=self) + @property def size(self) -> int: """ diff --git a/src/universal_transfer_operator/datasets/table.py b/src/universal_transfer_operator/datasets/table.py index 68aedb8..45f48e8 100644 --- a/src/universal_transfer_operator/datasets/table.py +++ b/src/universal_transfer_operator/datasets/table.py @@ -1,12 +1,17 @@ from __future__ import annotations -from urllib.parse import urlparse +import random +import string +from typing import Any from attr import define, field, fields_dict -from sqlalchemy import Column +from sqlalchemy import Column, MetaData from universal_transfer_operator.datasets.base import Dataset +MAX_TABLE_NAME_LENGTH = 62 +TEMP_PREFIX = "_tmp" + @define class Metadata: @@ -50,42 +55,125 @@ class Table(Dataset): uri: str = field(init=False) extra: dict = field(init=True, factory=dict) - @property - def sql_type(self): - raise NotImplementedError - def exists(self): """Check if the table exists or not""" raise NotImplementedError - def __str__(self) -> str: - return self.path + def _create_unique_table_name(self, prefix: str = "") -> str: + """ + If a table is instantiated without a name, create a unique table for it. + This new name should be compatible with all supported databases. + """ + schema_length = len((self.metadata and self.metadata.schema) or "") + 1 + prefix_length = len(prefix) + + unique_id = random.choice(string.ascii_lowercase) + "".join( + random.choice(string.ascii_lowercase + string.digits) + for _ in range(MAX_TABLE_NAME_LENGTH - schema_length - prefix_length) + ) + if prefix: + unique_id = f"{prefix}{unique_id}" - def __hash__(self) -> int: - return hash((self.path, self.conn_id)) + return unique_id - def dataset_scheme(self): + def create_similar_table(self) -> Table: """ - Return the scheme based on path + Create a new table with a unique name but with the same metadata. 
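As an illustration of the two helpers above (hypothetical table name, relying on the attrs-generated equality of `Metadata`):

```python
from universal_transfer_operator.datasets.table import Metadata, Table

source = Table(name="uto_s3_table", conn_id="sqlite_default", metadata=Metadata(schema="tmp_transfers"))
temp = source.create_similar_table()

# Same connection and metadata, but a freshly generated name capped at 62 characters.
assert (temp.conn_id, temp.metadata) == (source.conn_id, source.metadata)
assert temp.name != source.name
assert len(temp.name) <= 62
```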
""" - parsed = urlparse(self.path) - return parsed.scheme + return Table( # type: ignore + name=self._create_unique_table_name(), + conn_id=self.conn_id, + metadata=self.metadata, + ) + + @property + def sqlalchemy_metadata(self) -> MetaData: + """Return the Sqlalchemy metadata for the given table.""" + if self.metadata and self.metadata.schema: + alchemy_metadata = MetaData(schema=self.metadata.schema) + else: + alchemy_metadata = MetaData() + return alchemy_metadata - def dataset_namespace(self): + @property + def row_count(self) -> Any: + """ + Return the row count of table. """ - The namespace of a dataset can be combined to form a URI (scheme:[//authority]path) + from universal_transfer_operator.data_providers import create_dataprovider + + database_provider = create_dataprovider(dataset=self) + return database_provider.row_count(self) - Namespace = scheme:[//authority] (the dataset) + @property + def sql_type(self) -> Any: + from universal_transfer_operator.data_providers import create_dataprovider + + if self.conn_id: + return create_dataprovider(dataset=self).sql_type + + def to_json(self): + return { + "class": "Table", + "name": self.name, + "metadata": { + "schema": self.metadata.schema, + "database": self.metadata.database, + }, + "temp": self.temp, + "conn_id": self.conn_id, + } + + @classmethod + def from_json(cls, obj: dict): + return Table( + name=obj["name"], + metadata=Metadata(**obj["metadata"]), + temp=obj["temp"], + conn_id=obj["conn_id"], + ) + + def openlineage_dataset_name(self) -> str: + """ + Returns the open lineage dataset name as per + https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md """ - parsed = urlparse(self.path) - namespace = f"{self.dataset_scheme()}://{parsed.netloc}" - return namespace + from universal_transfer_operator.data_providers import create_dataprovider - def dataset_name(self): + database_provider = create_dataprovider(dataset=self) + return database_provider.openlineage_dataset_name(table=self) + + def openlineage_dataset_namespace(self) -> str: + """ + Returns the open lineage dataset namespace as per + https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md """ - The name of a dataset can be combined to form a URI (scheme:[//authority]path) + from universal_transfer_operator.data_providers import create_dataprovider - Name = path (the datasets) + database_provider = create_dataprovider(dataset=self) + return database_provider.openlineage_dataset_namespace() + + def openlineage_dataset_uri(self) -> str: + """ + Returns the open lineage dataset uri as per + https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md """ - parsed = urlparse(self.path) - return parsed.path if self.path else self.name + from universal_transfer_operator.data_providers import create_dataprovider + + database_provider = create_dataprovider(dataset=self) + return f"{database_provider.openlineage_dataset_uri(table=self)}" + + @uri.default + def _path_to_dataset_uri(self) -> str: + """Build a URI to be passed to Dataset obj introduced in Airflow 2.4""" + from urllib.parse import urlencode, urlparse + + path = f"astro://{self.conn_id}@" + db_extra = {"table": self.name} + if self.metadata.schema: + db_extra["schema"] = self.metadata.schema + if self.metadata.database: + db_extra["database"] = self.metadata.database + parsed_url = urlparse(url=path) + new_parsed_url = parsed_url._replace(query=urlencode(db_extra)) + return new_parsed_url.geturl() diff --git a/src/universal_transfer_operator/settings.py 
b/src/universal_transfer_operator/settings.py index 280d241..57a4856 100644 --- a/src/universal_transfer_operator/settings.py +++ b/src/universal_transfer_operator/settings.py @@ -4,6 +4,8 @@ from airflow.version import version as airflow_version from packaging.version import Version +from universal_transfer_operator.constants import DEFAULT_SCHEMA + # Section name for universal transfer operator configs in airflow.cfg SECTION_KEY = "universal_transfer_operator" @@ -23,3 +25,20 @@ # We only need PandasDataframe and other custom serialization and deserialization # if Airflow >= 2.5 and Pickling is not enabled and neither Custom XCom backend is used NEED_CUSTOM_SERIALIZATION = AIRFLOW_25_PLUS and IS_BASE_XCOM_BACKEND and not ENABLE_XCOM_PICKLING + +# Bigquery list of all the valid locations: https://cloud.google.com/bigquery/docs/locations +DEFAULT_BIGQUERY_SCHEMA_LOCATION = "us" +SCHEMA = conf.get(SECTION_KEY, "sql_schema", fallback=DEFAULT_SCHEMA) +POSTGRES_SCHEMA = conf.get(SECTION_KEY, "postgres_default_schema", fallback=SCHEMA) +BIGQUERY_SCHEMA = conf.get(SECTION_KEY, "bigquery_default_schema", fallback=SCHEMA) +SNOWFLAKE_SCHEMA = conf.get(SECTION_KEY, "snowflake_default_schema", fallback=SCHEMA) +REDSHIFT_SCHEMA = conf.get(SECTION_KEY, "redshift_default_schema", fallback=SCHEMA) +MSSQL_SCHEMA = conf.get(SECTION_KEY, "mssql_default_schema", fallback=SCHEMA) + +BIGQUERY_SCHEMA_LOCATION = conf.get( + SECTION_KEY, "bigquery_dataset_location", fallback=DEFAULT_BIGQUERY_SCHEMA_LOCATION +) + +LOAD_TABLE_AUTODETECT_ROWS_COUNT = conf.getint( + section=SECTION_KEY, key="load_table_autodetect_rows_count", fallback=1000 +) diff --git a/test-connections.yaml b/test-connections.yaml new file mode 100644 index 0000000..90a4f34 --- /dev/null +++ b/test-connections.yaml @@ -0,0 +1,142 @@ +connections: + - conn_id: postgres_conn + conn_type: postgres + host: localhost + schema: + login: postgres + password: postgres + port: 5432 + extra: + - conn_id: postgres_conn_pagila + conn_type: postgres + host: localhost + schema: pagila + login: postgres + password: postgres + port: 5432 + extra: + - conn_id: postgres_benchmark_conn + conn_type: postgres + host: postgres + schema: + login: postgres + password: postgres + port: 5432 + extra: + - conn_id: snowflake_conn + conn_type: snowflake + host: https://gp21411.us-east-1.snowflakecomputing.com + port: 443 + login: $SNOWFLAKE_ACCOUNT_NAME + password: $SNOWFLAKE_PASSWORD + schema: ASTROFLOW_CI + extra: + account: "gp21411" + region: "us-east-1" + role: "AIRFLOW_TEST_USER" + warehouse: ROBOTS + database: SANDBOX + - conn_id: snowflake_conn_1 + conn_type: snowflake + host: https://gp21411.us-east-1.snowflakecomputing.com + port: 443 + login: $SNOWFLAKE_ACCOUNT_NAME + password: $SNOWFLAKE_PASSWORD + extra: + account: "gp21411" + region: "us-east-1" + role: "AIRFLOW_TEST_USER" + warehouse: ROBOTS + - conn_id: bigquery + conn_type: bigquery + description: null + extra: + project: "astronomer-dag-authoring" + host: null + login: null + password: null + port: null + schema: null + - conn_id: sqlite_conn + conn_type: sqlite + host: /tmp/sqlite.db + schema: + login: + password: + - conn_id: gcp_conn + conn_type: google_cloud_platform + description: null + extra: null + - conn_id: gdrive_conn + conn_type: google_cloud_platform + description: connection to test google drive as file location + extra: '{"extra__google_cloud_platform__scope":"https://www.googleapis.com/auth/drive.readonly"}' + - conn_id: aws_conn + conn_type: aws + description: null + extra: null + - conn_id: 
redshift_conn + conn_type: redshift + schema: $REDSHIFT_DATABASE + host: $REDSHIFT_HOST + port: 5439 + login: $REDSHIFT_USERNAME + password: $REDSHIFT_PASSWORD + - conn_id: s3_conn_benchmark + conn_type: aws + description: null + extra: + aws_access_key_id: $AWS_ACCESS_KEY_ID + aws_secret_access_key: $AWS_SECRET_ACCESS_KEY + - conn_id: databricks_conn + conn_type: delta + host: https://dbc-9c390870-65ef.cloud.databricks.com/ + password: $DATABRICKS_TOKEN + extra: + http_path: /sql/1.0/warehouses/cf414a2206dfb397 + - conn_id: wasb_default + conn_type: wasb + description: null + extra: + connection_string: $AZURE_WASB_CONN_STRING + - conn_id: gcp_conn_project + conn_type: google_cloud_platform + description: null + extra: + project: "astronomer-dag-authoring" + project_id: "astronomer-dag-authoring" + - conn_id: sftp_conn + conn_type: sftp + host: localhost + login: foo + password: pass + port: 2222 + - conn_id: ftp_conn + conn_type: ftp + host: $SFTP_HOSTNAME + login: $SFTP_USERNAME + password: $SFTP_PASSWORD + port: 21 + extra: + - conn_id: mssql_conn + conn_type: mssql + host: $MSSQL_HOST + login: $MSSQL_LOGIN + password: $MSSQL_PASSWORD + port: 1433 + schema: $MSSQL_DB + extra: + - conn_id: duckdb_conn + conn_type: duckdb + host: /tmp/ciduckdb.duckdb + schema: + login: + password: + - conn_id: minio_conn + conn_type: aws + description: null + extra: + aws_access_key_id: "ROOTNAME" + aws_secret_access_key: "CHANGEME123" + endpoint_url: "http://127.0.0.1:9000" + minio: True diff --git a/tests/conftest.py b/tests/conftest.py index bceda4c..d09fa41 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,6 @@ import logging import os +from copy import deepcopy import pytest import yaml @@ -9,9 +10,17 @@ from airflow.utils.session import create_session from utils.test_utils import create_unique_str +from universal_transfer_operator.constants import TransferMode +from universal_transfer_operator.data_providers import create_dataprovider +from universal_transfer_operator.datasets.table import Table + DEFAULT_DATE = timezone.datetime(2016, 1, 1) UNIQUE_HASH_SIZE = 16 +DATASET_NAME_TO_CONN_ID = { + "SqliteDataProvider": "sqlite_default", +} + @pytest.fixture def sample_dag(): @@ -45,3 +54,49 @@ def create_database_connections(): conn.conn_id, ) session.add(conn) + + +@pytest.fixture +def dataset_table_fixture(request): + """ + Given request.param in the format: + { + "database": Database.SQLITE, # mandatory, may be any supported database + "table": astro.sql.tables.Table(), # optional, will create a table unless it is passed + "file": "" # optional, File() instance to be used to load data to the table. + } + This fixture yields the following during setup: + (database, table) + Example: + (astro.databases.sqlite.SqliteDatabase(), Table()) + If the table exists, it is deleted during the tests setup and tear down. + The table will only be created during setup if request.param contains the `file` parameter. + """ + # We deepcopy the request param dictionary as we modify the table item directly. 
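A test can then request this fixture indirectly, in the same style as the sqlite tests added later in this patch; the test name below is illustrative and assumes the `sqlite_default` connection used in CI:

```python
import pytest


@pytest.mark.integration
@pytest.mark.parametrize(
    "dataset_table_fixture",
    [{"dataset": "SqliteDataProvider"}],
    indirect=True,
    ids=["sqlite"],
)
def test_fixture_yields_provider_and_table(dataset_table_fixture):
    # The fixture yields (data_provider, table) and drops the table on teardown.
    dp, table = dataset_table_fixture
    assert table.conn_id == "sqlite_default"
    assert table.name  # a unique name is generated when none is supplied
```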
+ params = deepcopy(request.param) + + dataset_name = params["dataset"] + user_table = params.get("table", None) + transfer_mode = params.get("transfer_mode", TransferMode.NONNATIVE) + conn_id = DATASET_NAME_TO_CONN_ID[dataset_name] + if user_table and user_table.conn_id: + conn_id = user_table.conn_id + + table = user_table or Table(conn_id=conn_id) + if not table.conn_id: + table.conn_id = conn_id + + dp = create_dataprovider(dataset=table, transfer_mode=transfer_mode) + + if not table.name: + # We create a unique table name to make the name unique across runs + table.name = create_unique_str(UNIQUE_HASH_SIZE) + file = params.get("file") + + dp.populate_table_metadata(table) + dp.create_schema_if_needed(table.metadata.schema) + + if file: + dp.load_file_to_table(file, table) + yield dp, table + dp.drop_table(table) diff --git a/tests/test_data_provider/test_data_provider.py b/tests/test_data_provider/test_data_provider.py index d16bd2a..811750c 100644 --- a/tests/test_data_provider/test_data_provider.py +++ b/tests/test_data_provider/test_data_provider.py @@ -1,10 +1,12 @@ import pytest from universal_transfer_operator.data_providers import create_dataprovider +from universal_transfer_operator.data_providers.database.sqlite import SqliteDataProvider from universal_transfer_operator.data_providers.filesystem.aws.s3 import S3DataProvider from universal_transfer_operator.data_providers.filesystem.google.cloud.gcs import GCSDataProvider from universal_transfer_operator.data_providers.filesystem.sftp import SFTPDataProvider from universal_transfer_operator.datasets.file.base import File +from universal_transfer_operator.datasets.table import Table @pytest.mark.parametrize( @@ -13,6 +15,7 @@ {"dataset": File("s3://astro-sdk-test/uto/", conn_id="aws_default"), "expected": S3DataProvider}, {"dataset": File("gs://uto-test/uto/", conn_id="google_cloud_default"), "expected": GCSDataProvider}, {"dataset": File("sftp://upload/sample.csv", conn_id="sftp_default"), "expected": SFTPDataProvider}, + {"dataset": Table("some_table", conn_id="sqlite_default"), "expected": SqliteDataProvider}, ], ids=lambda d: d["dataset"].conn_id, ) diff --git a/tests/test_data_provider/test_database/test_base.py b/tests/test_data_provider/test_database/test_base.py new file mode 100644 index 0000000..9bd0224 --- /dev/null +++ b/tests/test_data_provider/test_database/test_base.py @@ -0,0 +1,73 @@ +import pathlib + +import pytest +from pandas import DataFrame + +from universal_transfer_operator.constants import TransferMode +from universal_transfer_operator.data_providers.database.base import DatabaseDataProvider +from universal_transfer_operator.datasets.file.base import File +from universal_transfer_operator.datasets.table import Table + +CWD = pathlib.Path(__file__).parent + + +class DatabaseDataProviderSubclass(DatabaseDataProvider): + pass + + +def test_openlineage_database_dataset_namespace(): + """ + Test the open lineage dataset namespace for base class + """ + db = DatabaseDataProviderSubclass(dataset=Table(name="test"), transfer_mode=TransferMode.NONNATIVE) + with pytest.raises(NotImplementedError): + db.openlineage_dataset_namespace() + + +def test_openlineage_database_dataset_name(): + """ + Test the open lineage dataset names for the base class + """ + db = DatabaseDataProviderSubclass(dataset=Table(name="test"), transfer_mode=TransferMode.NONNATIVE) + with pytest.raises(NotImplementedError): + db.openlineage_dataset_name(table=Table) + + +def test_subclass_missing_not_implemented_methods_raise_exception(): + db 
+    with pytest.raises(NotImplementedError):
+        db.hook
+
+    with pytest.raises(NotImplementedError):
+        db.sqlalchemy_engine
+
+    with pytest.raises(NotImplementedError):
+        db.connection
+
+    with pytest.raises(NotImplementedError):
+        db.default_metadata
+
+    with pytest.raises(NotImplementedError):
+        db.run_sql("SELECT * FROM inexistent_table")
+
+
+def test_create_table_using_native_schema_autodetection_not_implemented():
+    db = DatabaseDataProviderSubclass(dataset=Table(name="test"), transfer_mode=TransferMode.NONNATIVE)
+    with pytest.raises(NotImplementedError):
+        db.create_table_using_native_schema_autodetection(table=Table(), file=File(path="s3://bucket/key"))
+
+
+def test_subclass_missing_load_pandas_dataframe_to_table_raises_exception():
+    db = DatabaseDataProviderSubclass(dataset=Table(name="test"), transfer_mode=TransferMode.NONNATIVE)
+    table = Table()
+    df = DataFrame(data={"col1": [1, 2], "col2": [3, 4]})
+    with pytest.raises(NotImplementedError):
+        db.load_pandas_dataframe_to_table(df, table)
+
+
+def test_create_table_using_columns_raises_exception():
+    db = DatabaseDataProviderSubclass(dataset=Table(name="test"), transfer_mode=TransferMode.NONNATIVE)
+    table = Table()
+    with pytest.raises(ValueError) as exc_info:
+        db.create_table_using_columns(table)
+    assert exc_info.match("To use this method, table.columns must be defined")
diff --git a/tests/test_data_provider/test_database/test_sqlite.py b/tests/test_data_provider/test_database/test_sqlite.py
new file mode 100644
index 0000000..78411bc
--- /dev/null
+++ b/tests/test_data_provider/test_database/test_sqlite.py
@@ -0,0 +1,164 @@
+import pathlib
+
+import pandas as pd
+import pytest
+import sqlalchemy
+from airflow.hooks.base import BaseHook
+
+from universal_transfer_operator.constants import TransferMode
+from universal_transfer_operator.data_providers.database.sqlite import SqliteDataProvider
+from universal_transfer_operator.datasets.table import Table
+
+CWD = pathlib.Path(__file__).parent
+
+DEFAULT_CONN_ID = "sqlite_default"
+CUSTOM_CONN_ID = "sqlite_conn"
+SUPPORTED_CONN_IDS = [DEFAULT_CONN_ID, CUSTOM_CONN_ID]
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize(
+    "conn_id,expected_db_path",
+    [
+        (
+            DEFAULT_CONN_ID,
+            BaseHook.get_connection(DEFAULT_CONN_ID).host,
+        ),  # Linux and MacOS have different hosts
+        (CUSTOM_CONN_ID, "/tmp/sqlite.db"),
+    ],
+    ids=SUPPORTED_CONN_IDS,
+)
+def test_sqlite_sqlalchemy_engine(conn_id, expected_db_path):
+    """Confirm that the SQLAlchemy engine is created successfully and verify the DB path."""
+    dp = SqliteDataProvider(
+        dataset=Table("some_table", conn_id=conn_id), transfer_mode=TransferMode.NONNATIVE
+    )
+    engine = dp.sqlalchemy_engine
+    assert isinstance(engine, sqlalchemy.engine.base.Engine)
+    assert engine.url.database == expected_db_path
+
+
+@pytest.mark.integration
+def test_sqlite_run_sql_with_sqlalchemy_text():
+    """Run a SQL statement using SQLAlchemy text."""
+    statement = sqlalchemy.text("SELECT 1 + 1;")
+    dp = SqliteDataProvider(
+        dataset=Table("some_table", conn_id="sqlite_default"), transfer_mode=TransferMode.NONNATIVE
+    )
+    response = dp.run_sql(statement)
+    assert response.first()[0] == 2
+
+
+@pytest.mark.integration
+def test_sqlite_run_sql():
+    """Run a SQL statement using a plain string."""
+    statement = "SELECT 1 + 1;"
+    dp = SqliteDataProvider(
+        dataset=Table("some_table", conn_id="sqlite_default"), transfer_mode=TransferMode.NONNATIVE
+    )
+    response = dp.run_sql(statement)
+    assert response.first()[0] == 2
+
+
+@pytest.mark.integration
+def test_sqlite_run_sql_with_parameters():
+    """Run a SQL query using SQLAlchemy bound parameters."""
+    statement = "SELECT 1 + :value;"
+    dp = SqliteDataProvider(
+        dataset=Table("some_table", conn_id="sqlite_default"), transfer_mode=TransferMode.NONNATIVE
+    )
+    response = dp.run_sql(statement, parameters={"value": 1})
+    assert response.first()[0] == 2
+
+
+@pytest.mark.integration
+def test_table_exists_raises_exception():
+    """Report a non-existent table as absent."""
+    dp = SqliteDataProvider(
+        dataset=Table("some_table", conn_id="sqlite_default"), transfer_mode=TransferMode.NONNATIVE
+    )
+    assert not dp.table_exists(Table(name="inexistent-table"))
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize(
+    "dataset_table_fixture",
+    [
+        {
+            "dataset": "SqliteDataProvider",
+            "table": Table(
+                columns=[
+                    sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True),
+                    sqlalchemy.Column("name", sqlalchemy.String(60), nullable=False, key="name"),
+                ]
+            ),
+        },
+    ],
+    indirect=True,
+    ids=["sqlite"],
+)
+def test_sqlite_create_table_with_columns(dataset_table_fixture):
+    """Create a table using specific columns and types."""
+    dp, table = dataset_table_fixture
+
+    statement = f"PRAGMA table_info({table.name});"
+    response = dp.run_sql(statement)
+    assert len(response.fetchall()) == 0
+
+    dp.create_table(table)
+    response = dp.run_sql(statement)
+    rows = response.fetchall()
+    assert len(rows) == 2
+    assert rows[0] == (0, "id", "INTEGER", 1, None, 1)
+    assert rows[1] == (1, "name", "VARCHAR(60)", 1, None, 0)
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize(
+    "dataset_table_fixture",
+    [
+        {
+            "dataset": "SqliteDataProvider",
+        },
+    ],
+    indirect=True,
+    ids=["sqlite"],
+)
+def test_sqlite_create_table_autodetection_without_file(dataset_table_fixture):
+    """Attempt to create a table via schema autodetection without a file or dataframe."""
+    dp, table = dataset_table_fixture
+
+    statement = f"PRAGMA table_info({table.name});"
+    response = dp.run_sql(statement, handler=lambda x: x.first())
+    assert response.fetchall() == []
+
+    with pytest.raises(ValueError) as exc_info:
+        dp.create_table(table)
+    assert exc_info.match("File or Dataframe is required for creating table using schema autodetection")
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize(
+    "dataset_table_fixture",
+    [
+        {
+            "dataset": "SqliteDataProvider",
+        },
+    ],
+    indirect=True,
+    ids=["sqlite"],
+)
+def test_load_pandas_dataframe_to_table(dataset_table_fixture):
+    """Load a Pandas DataFrame into a SQLite table."""
+    database, table = dataset_table_fixture
+
+    pandas_dataframe = pd.DataFrame(data={"id": [1, 2]})
+    database.load_pandas_dataframe_to_table(pandas_dataframe, table)
+
+    statement = f"SELECT * FROM {table.name};"
+    response = database.run_sql(statement, handler=lambda x: x.fetchall())
+
+    rows = response.fetchall()
+    assert len(rows) == 2
+    assert rows[0] == (1,)
+    assert rows[1] == (2,)
diff --git a/tests/test_data_provider/test_filesystem/test_sftp.py b/tests/test_data_provider/test_filesystem/test_sftp.py
index 6d0f63a..3b7ccaf 100644
--- a/tests/test_data_provider/test_filesystem/test_sftp.py
+++ b/tests/test_data_provider/test_filesystem/test_sftp.py
@@ -51,7 +51,11 @@ def test_sftp_write():
 
     remote_filepath = f"sftp://upload/{file_name}"
     dataprovider = create_dataprovider(dataset=File(path=remote_filepath, conn_id="sftp_conn"))
-    fs = FileStream(remote_obj_buffer=open(local_filepath), actual_filename=local_filepath)
+    fs = FileStream(
+        remote_obj_buffer=open(local_filepath),
+        actual_filename=local_filepath,
+        actual_file=File(local_filepath),
+    )
     dataprovider.write(source_ref=fs)
 
     downloaded_file = f"/tmp/{file_name}"
diff --git a/tests/utils/test_utils.py b/tests/utils/test_utils.py
index 7e02769..0f9e104 100644
--- a/tests/utils/test_utils.py
+++ b/tests/utils/test_utils.py
@@ -1,6 +1,9 @@
 import random
 import string
 
+import pandas as pd
+from pandas.testing import assert_frame_equal
+
 
 def create_unique_str(length: int = 50) -> str:
     """
@@ -12,3 +15,13 @@ def create_unique_str(length: int = 50) -> str:
         random.choice(string.ascii_lowercase + string.digits) for _ in range(length - 1)
     )
     return unique_id
+
+
+def assert_dataframes_are_equal(df: pd.DataFrame, expected: pd.DataFrame) -> None:
+    """
+    Helper that asserts two dataframes are equal, so this logic is not repeated across tests.
+    """
+    df = df.rename(columns=str.lower)
+    df = df.astype({"id": "int64"})
+    expected = expected.astype({"id": "int64"})
+    assert_frame_equal(df, expected)