From 62c895a93c746a1536557701af90195c09653948 Mon Sep 17 00:00:00 2001
From: Ankit Chaurasia <8670962+sunank200@users.noreply.github.com>
Date: Fri, 13 Jan 2023 16:43:54 +0545
Subject: [PATCH] Add universal transfer operator interfaces with working example (#1492)

# Description

## What is the current behavior?

This PR supersedes https://github.com/astronomer/astro-sdk/pull/1267 and continues that work here.

closes: #1139
closes: #1544
closes: #1545
closes: #1546
closes: #1551

## What is the new behavior?

- Define interfaces:
  - Use the Airflow 2.4 `Dataset` concept to build more types of datasets:
    - Table
    - File
    - Dataframe
    - API
  - Define an interface for the universal transfer operator.
  - Add the `TransferParameters` class to pass transfer configurations.
  - Use a context manager from `DataProvider` for cleanup.
  - Introduce three transfer modes: `native`, `non-native` and `third-party`.
- `DataProviders`:
  - Add an interface for `DataProvider`.
  - Add an interface for `BaseFilesystemProviders`.
  - Add `read` and `write` methods to `DataProviders` with the context manager.
- `TransferIntegrations` and third-party transfers:
  - Add an interface for `TransferIntegrations` and introduce the third-party transfer approach.
- Non-native transfers:
  - Add a `DataProvider` for S3 and GCS.
  - Add a transfer workflow for S3 to GCS using the non-native approach.
  - Add a transfer workflow for GCS to S3 using the non-native approach.
  - Add an example DAG for the S3 to GCS implementation.
  - Add an example DAG for the GCS to S3 implementation.
- Third-party transfers:
  - Add the `FivetranIntegration` class for all transfers using Fivetran.
  - Implement `FiveTranOptions`, which inherits from the `TransferParameters` class, to pass transfer configurations.
  - Implement a POC for the Fivetran integration.
  - Add an example DAG for the Fivetran implementation.
  - Fivetran POC with a working DAG for a transfer example (S3 to Snowflake) when `connector_id` is passed.
  - Document the APIs for Fivetran transfers on Notion: https://www.notion.so/astronomerio/Fivetran-3bd9ecfbdcae411faa49cb38595a4571
- Add a Makefile and Dockerfile, along with docker-compose.yaml, to build the project locally and in a container.
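For quick reference, here is a minimal usage sketch based on the example DAG added in this PR (`example_dags/example_universal_transfer_operator.py`); the bucket paths, connection IDs, and `connector_id` are placeholders you would replace with your own:

```python
from datetime import datetime

from airflow import DAG

from universal_transfer_operator.constants import TransferMode
from universal_transfer_operator.datasets.file import File
from universal_transfer_operator.datasets.table import Table
from universal_transfer_operator.integrations.fivetran import FiveTranOptions
from universal_transfer_operator.universal_transfer_operator import UniversalTransferOperator

with DAG(
    "example_universal_transfer_operator",
    schedule_interval=None,
    start_date=datetime(2022, 1, 1),
    catchup=False,
) as dag:
    # Non-native transfer: the source DataProvider reads S3 objects into temporary
    # files and the destination DataProvider writes them to GCS, cleaning up the
    # temporary files via the context manager.
    transfer_non_native_s3_to_gs = UniversalTransferOperator(
        task_id="transfer_non_native_s3_to_gs",
        source_dataset=File(path="s3://astro-sdk-test/uto/", conn_id="aws_default"),
        destination_dataset=File(path="gs://uto-test/uto/", conn_id="google_cloud_default"),
    )

    # Third-party transfer: the work is delegated to Fivetran by passing an
    # existing connector_id through FiveTranOptions.
    transfer_fivetran_with_connector_id = UniversalTransferOperator(
        task_id="transfer_fivetran_with_connector_id",
        source_dataset=File(path="s3://astro-sdk-test/uto/", conn_id="aws_default"),
        destination_dataset=Table(name="fivetran_test", conn_id="snowflake_default"),
        transfer_mode=TransferMode.THIRDPARTY,
        transfer_params=FiveTranOptions(conn_id="fivetran_default", connector_id="filing_muppet"),
    )
```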
## Does this introduce a breaking change?

No

### Checklist

- [x] Example DAG
- [x] Created tests which fail without the change (if possible)
- [x] Extended the README/documentation, if necessary

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Kaxil Naik
Co-authored-by: Felix Uellendall
---
 Makefile | 18 ++
 README.md | 4 +
 dev/Dockerfile | 18 ++
 dev/docker-compose.yaml | 187 +++++++++++
 .../example_universal_transfer_operator.py | 93 ++++++
 mk/container.mk | 33 ++
 mk/local.mk | 71 ++++
 pyproject.toml | 129 ++++++++
 src/universal_transfer_operator/__init__.py | 20 ++
 src/universal_transfer_operator/constants.py | 99 ++++++
 .../data_providers/__init__.py | 31 ++
 .../data_providers/base.py | 101 ++++++
 .../data_providers/filesystem/__init__.py | 0
 .../data_providers/filesystem/aws/__init__.py | 0
 .../data_providers/filesystem/aws/s3.py | 160 +++++++++
 .../data_providers/filesystem/base.py | 110 +++++++
 .../filesystem/google/__init__.py | 0
 .../filesystem/google/cloud/__init__.py | 0
 .../filesystem/google/cloud/gcs.py | 166 ++++++++++
 .../datasets/__init__.py | 0
 .../datasets/apis.py | 17 +
 .../datasets/base.py | 11 +
 .../datasets/dataframe.py | 21 ++
 .../datasets/file.py | 116 +++++++
 .../datasets/table.py | 91 ++++++
 .../integrations/__init__.py | 30 ++
 .../integrations/base.py | 45 +++
 .../integrations/fivetran.py | 306 ++++++++++++++++++
 .../universal_transfer_operator.py | 69 ++++
 src/universal_transfer_operator/utils.py | 60 ++++
 30 files changed, 2006 insertions(+)
 create mode 100644 Makefile
 create mode 100644 README.md
 create mode 100644 dev/Dockerfile
 create mode 100644 dev/docker-compose.yaml
 create mode 100644 example_dags/example_universal_transfer_operator.py
 create mode 100644 mk/container.mk
 create mode 100644 mk/local.mk
 create mode 100644 pyproject.toml
 create mode 100644 src/universal_transfer_operator/__init__.py
 create mode 100644 src/universal_transfer_operator/constants.py
 create mode 100644 src/universal_transfer_operator/data_providers/__init__.py
 create mode 100644 src/universal_transfer_operator/data_providers/base.py
 create mode 100644 src/universal_transfer_operator/data_providers/filesystem/__init__.py
 create mode 100644 src/universal_transfer_operator/data_providers/filesystem/aws/__init__.py
 create mode 100644 src/universal_transfer_operator/data_providers/filesystem/aws/s3.py
 create mode 100644 src/universal_transfer_operator/data_providers/filesystem/base.py
 create mode 100644 src/universal_transfer_operator/data_providers/filesystem/google/__init__.py
 create mode 100644 src/universal_transfer_operator/data_providers/filesystem/google/cloud/__init__.py
 create mode 100644 src/universal_transfer_operator/data_providers/filesystem/google/cloud/gcs.py
 create mode 100644 src/universal_transfer_operator/datasets/__init__.py
 create mode 100644 src/universal_transfer_operator/datasets/apis.py
 create mode 100644 src/universal_transfer_operator/datasets/base.py
 create mode 100644 src/universal_transfer_operator/datasets/dataframe.py
 create mode 100644 src/universal_transfer_operator/datasets/file.py
 create mode 100644 src/universal_transfer_operator/datasets/table.py
 create mode 100644 src/universal_transfer_operator/integrations/__init__.py
 create mode 100644 src/universal_transfer_operator/integrations/base.py
 create mode 100644 src/universal_transfer_operator/integrations/fivetran.py
 create mode 100644 src/universal_transfer_operator/universal_transfer_operator.py
 create mode 100644 src/universal_transfer_operator/utils.py
diff --git a/Makefile
b/Makefile new file mode 100644 index 0000000..fbaa2b1 --- /dev/null +++ b/Makefile @@ -0,0 +1,18 @@ +.PHONY: help + +.DEFAULT_GOAL:= help + +target = help + +ifdef "$(target)" + target = $(target) +endif + +container: ## Set up Airflow in container + @$(MAKE) -C mk -f container.mk $(target) + +local: ## Set up local dev env + @$(MAKE) -C mk -f local.mk $(target) + +help: + @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-41s\033[0m %s\n", $$1, $$2}' diff --git a/README.md b/README.md new file mode 100644 index 0000000..0d4b44b --- /dev/null +++ b/README.md @@ -0,0 +1,4 @@ +Universal Transfer Operator + +To build locally: +```make container target=build-run``` diff --git a/dev/Dockerfile b/dev/Dockerfile new file mode 100644 index 0000000..9682ed7 --- /dev/null +++ b/dev/Dockerfile @@ -0,0 +1,18 @@ +FROM quay.io/astronomer/astro-runtime:7.1.0-base + +USER root +RUN apt-get update -y && apt-get install -y git +RUN apt-get install -y --no-install-recommends \ + build-essential \ + libsasl2-2 \ + libsasl2-dev \ + libsasl2-modules +ENV SETUPTOOLS_USE_DISTUTILS=stdlib + +COPY ../pyproject.toml ${AIRFLOW_HOME}/universal_transfer_operator/ +# The following file are needed because version they are referenced from pyproject.toml +COPY ../README.md ${AIRFLOW_HOME}/universal_transfer_operator/ +COPY ../src/universal_transfer_operator/__init__.py ${AIRFLOW_HOME}/universal_transfer_operator/src/universal_transfer_operator/__init__.py + +RUN pip install -e "${AIRFLOW_HOME}/universal_transfer_operator[all]" +USER astro diff --git a/dev/docker-compose.yaml b/dev/docker-compose.yaml new file mode 100644 index 0000000..30e10c9 --- /dev/null +++ b/dev/docker-compose.yaml @@ -0,0 +1,187 @@ +--- +version: '3' +x-airflow-common: + &airflow-common + image: astro-sdk-dev + build: + context: .. 
+ dockerfile: dev/Dockerfile + environment: + &airflow-common-env + DB_BACKEND: postgres + AIRFLOW__CORE__EXECUTOR: CeleryExecutor + AIRFLOW__DATABASE__SQL_ALCHEMY_CONN: postgresql+psycopg2://airflow:airflow@postgres:5432/airflow + AIRFLOW__CELERY__RESULT_BACKEND: db+postgresql://airflow:airflow@postgres:5432/airflow + AIRFLOW__CELERY__BROKER_URL: redis://:@redis:6379/0 + AIRFLOW__CORE__FERNET_KEY: '' + AIRFLOW__CORE__LOAD_EXAMPLES: "False" + AIRFLOW__WEBSERVER__EXPOSE_CONFIG: "True" + AIRFLOW__SCHEDULER__DAG_DIR_LIST_INTERVAL: "5" + ASTRONOMER_ENVIRONMENT: local + AIRFLOW__CORE__ALLOWED_DESERIALIZATION_CLASSES: airflow.* astro.* + AIRFLOW__LINEAGE__BACKEND: openlineage.lineage_backend.OpenLineageBackend + OPENLINEAGE_URL: http://host.docker.internal:5050/ + OPENLINEAGE_NAMESPACE: "astro" + volumes: + - ./dags:/usr/local/airflow/dags + - ./logs:/usr/local/airflow/logs + - ./plugins:/usr/local/airflow/plugins + - ../../universal_transfer_operator:/usr/local/airflow/universal_transfer_operator + depends_on: + &airflow-common-depends-on + redis: + condition: service_healthy + postgres: + condition: service_healthy + +services: + postgres: + image: postgres:13 + environment: + POSTGRES_USER: airflow + POSTGRES_PASSWORD: airflow + POSTGRES_DB: airflow + command: postgres -c 'idle_in_transaction_session_timeout=60000' # 1 minute timeout + volumes: + - postgres-db-volume:/var/lib/postgresql/data + ports: + - "5432:5432" + healthcheck: + test: ["CMD", "pg_isready", "-U", "airflow"] + interval: 5s + retries: 5 + restart: always + + redis: + image: redis:latest + expose: + - 6379 + healthcheck: + test: ["CMD", "redis-cli", "ping"] + interval: 5s + timeout: 30s + retries: 50 + restart: always + + airflow-webserver: + <<: *airflow-common + command: airflow webserver + ports: + - 8080:8080 + healthcheck: + test: ["CMD", "curl", "--fail", "http://localhost:8080/health"] + interval: 10s + timeout: 10s + retries: 5 + restart: always + depends_on: + <<: *airflow-common-depends-on + airflow-init: + condition: service_completed_successfully + + airflow-scheduler: + <<: *airflow-common + command: airflow scheduler + healthcheck: + test: ["CMD-SHELL", 'airflow jobs check --job-type SchedulerJob --hostname "$${HOSTNAME}"'] + interval: 10s + timeout: 10s + retries: 5 + restart: always + depends_on: + <<: *airflow-common-depends-on + airflow-init: + condition: service_completed_successfully + + airflow-worker: + <<: *airflow-common + command: airflow celery worker + healthcheck: + test: + - "CMD-SHELL" + - 'celery --app airflow.executors.celery_executor.app inspect ping -d "celery@$${HOSTNAME}"' + interval: 10s + timeout: 10s + retries: 5 + environment: + <<: *airflow-common-env + restart: always + depends_on: + <<: *airflow-common-depends-on + airflow-init: + condition: service_completed_successfully + + airflow-triggerer: + <<: *airflow-common + command: airflow triggerer + healthcheck: + test: ["CMD-SHELL", 'airflow jobs check --job-type TriggererJob --hostname "$${HOSTNAME}"'] + interval: 10s + timeout: 10s + retries: 5 + restart: always + depends_on: + <<: *airflow-common-depends-on + airflow-init: + condition: service_completed_successfully + + airflow-init: + <<: *airflow-common + entrypoint: /bin/bash + # yamllint disable rule:line-length + command: + - -c + - | + one_meg=1048576 + mem_available=$$(($$(getconf _PHYS_PAGES) * $$(getconf PAGE_SIZE) / one_meg)) + cpus_available=$$(grep -cE 'cpu[0-9]+' /proc/stat) + disk_available=$$(df / | tail -1 | awk '{print $$4}') + warning_resources="false" + if (( 
mem_available < 4000 )) ; then + echo + echo -e "\033[1;33mWARNING!!!: Not enough memory available for Docker.\e[0m" + echo "At least 4GB of memory required. You have $$(numfmt --to iec $$((mem_available * one_meg)))" + echo + warning_resources="true" + fi + if (( cpus_available < 2 )); then + echo + echo -e "\033[1;33mWARNING!!!: Not enough CPUS available for Docker.\e[0m" + echo "At least 2 CPUs recommended. You have $${cpus_available}" + echo + fi + if (( disk_available < one_meg * 10 )); then + echo + echo -e "\033[1;33mWARNING!!!: Not enough Disk space available for Docker.\e[0m" + echo "At least 10 GBs recommended. You have $$(numfmt --to iec $$((disk_available * 1024 )))" + echo + fi + exec /entrypoint bash -c " + airflow db upgrade && \ + airflow users create -r Admin -u admin -e admin -f admin -l admin -p admin && \ + airflow connections import /usr/local/airflow/universal_transfer_operator/dev/connections.yaml || true && \ + airflow version" + # yamllint enable rule:line-length + environment: + <<: *airflow-common-env + + flower: + <<: *airflow-common + command: airflow celery flower + ports: + - 5555:5555 + healthcheck: + test: ["CMD", "curl", "--fail", "http://localhost:5555/"] + interval: 10s + timeout: 10s + retries: 5 + environment: + <<: *airflow-common-env + restart: always + depends_on: + <<: *airflow-common-depends-on + airflow-init: + condition: service_completed_successfully + +volumes: + postgres-db-volume: diff --git a/example_dags/example_universal_transfer_operator.py b/example_dags/example_universal_transfer_operator.py new file mode 100644 index 0000000..f7d7b34 --- /dev/null +++ b/example_dags/example_universal_transfer_operator.py @@ -0,0 +1,93 @@ +import os +from datetime import datetime + +from airflow import DAG + +from universal_transfer_operator.constants import TransferMode +from universal_transfer_operator.datasets.file import File +from universal_transfer_operator.datasets.table import Metadata, Table +from universal_transfer_operator.integrations.fivetran import Connector, Destination, FiveTranOptions, Group +from universal_transfer_operator.universal_transfer_operator import UniversalTransferOperator + +with DAG( + "example_universal_transfer_operator", + schedule_interval=None, + start_date=datetime(2022, 1, 1), + catchup=False, +) as dag: + transfer_non_native_gs_to_s3 = UniversalTransferOperator( + task_id="transfer_non_native_gs_to_s3", + source_dataset=File(path="gs://uto-test/uto/", conn_id="google_cloud_default"), + destination_dataset=File(path="s3://astro-sdk-test/uto/", conn_id="aws_default"), + ) + + transfer_non_native_s3_to_gs = UniversalTransferOperator( + task_id="transfer_non_native_s3_to_gs", + source_dataset=File(path="s3://astro-sdk-test/uto/", conn_id="aws_default"), + destination_dataset=File( + path="gs://uto-test/uto/", + conn_id="google_cloud_default", + ), + ) + + transfer_fivetran_with_connector_id = UniversalTransferOperator( + task_id="transfer_fivetran_with_connector_id", + source_dataset=File(path="s3://astro-sdk-test/uto/", conn_id="aws_default"), + destination_dataset=Table(name="fivetran_test", conn_id="snowflake_default"), + transfer_mode=TransferMode.THIRDPARTY, + transfer_params=FiveTranOptions(conn_id="fivetran_default", connector_id="filing_muppet"), + ) + + transfer_fivetran_without_connector_id = UniversalTransferOperator( + task_id="transfer_fivetran_without_connector_id", + source_dataset=File(path="s3://astro-sdk-test/uto/", conn_id="aws_default"), + destination_dataset=Table( + name="fivetran_test", + 
conn_id="snowflake_default", + metadata=Metadata( + database=os.environ["SNOWFLAKE_DATABASE"], schema=os.environ["SNOWFLAKE_SCHEMA"] + ), + ), + transfer_mode=TransferMode.THIRDPARTY, + transfer_params=FiveTranOptions( + conn_id="fivetran_default", + connector_id="filing_muppet", + group=Group(name="test_group"), + connector=Connector( + service="s3", + config={ + "schema": "s3", + "append_file_option": "upsert_file", + "prefix": "folder_path", + "pattern": "file_pattern", + "escape_char": "", + "skip_after": 0, + "list_strategy": "complete_listing", + "bucket": "astro-sdk-test", + "empty_header": True, + "skip_before": 0, + "role_arn": "arn::your_role_arn", + "file_type": "csv", + "delimiter": "", + "is_public": False, + "on_error": "fail", + "compression": "bz2", + "table": "fivetran_test", + "archive_pattern": "regex_pattern", + }, + ), + destination=Destination( + service="snowflake", + time_zone_offset="-5", + region="GCP_US_EAST4", + config={ + "host": "your-account.snowflakecomputing.com", + "port": 443, + "database": "fivetran", + "auth": "PASSWORD", + "user": "fivetran_user", + "password": "123456", + }, + ), + ), + ) diff --git a/mk/container.mk b/mk/container.mk new file mode 100644 index 0000000..d47ed71 --- /dev/null +++ b/mk/container.mk @@ -0,0 +1,33 @@ +PHONY: build-run clean docs logs stop shell restart restart-all help + +.DEFAULT_GOAL:= help + +logs: ## View logs of the all the containers + docker compose -f ../dev/docker-compose.yaml logs --follow + +stop: ## Stop all the containers + docker compose -f ../dev/docker-compose.yaml down + +clean: ## Remove all the containers along with volumes + docker compose -f ../dev/docker-compose.yaml down --volumes --remove-orphans + rm -rf dev/logs + +build-run: ## Build the Docker Image & then run the containers + docker compose -f ../dev/docker-compose.yaml up --build -d + +docs: ## Build the docs using Sphinx + docker compose -f ../dev/docker-compose.yaml build + docker compose -f ../dev/docker-compose.yaml run --entrypoint /bin/bash airflow-init -c "cd universal_transfer_operator/docs && make clean html" + @echo "Documentation built in $(shell cd .. && pwd)/docs/_build/html/index.html" + +restart: ## Restart Triggerer, Scheduler and Worker containers + docker compose -f ../dev/docker-compose.yaml restart airflow-triggerer airflow-scheduler airflow-worker + +restart-all: ## Restart all the containers + docker compose -f ../dev/docker-compose.yaml restart + +shell: ## Runs a shell within a container (Allows interactive session) + docker compose -f ../dev/docker-compose.yaml run --rm airflow-scheduler bash + +help: ## Prints this message + @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-41s\033[0m %s\n", $$1, $$2}' diff --git a/mk/local.mk b/mk/local.mk new file mode 100644 index 0000000..8eec2a8 --- /dev/null +++ b/mk/local.mk @@ -0,0 +1,71 @@ +.PHONY: help +.DEFAULT_GOAL:= help +SHELL := /bin/bash +PROJECT_NAME := universal-transfer-operator +SYSTEM_PYTHON := python3.9 + +# Set default virtualenv path, if not defined +ifndef VIRTUALENV_PATH +$(shell mkdir -p ~/.virtualenvs/) +override VIRTUALENV_PATH = ~/.virtualenvs/$(PROJECT_NAME) +endif + +PYTHON = $(VIRTUALENV_PATH)/bin/python +PIP = $(VIRTUALENV_PATH)/bin/pip +PYTEST = $(VIRTUALENV_PATH)/bin/pytest +PRECOMMIT = $(VIRTUALENV_PATH)/bin/pre-commit + + +clean: ## Remove temporary files + @echo "Removing cached and temporary files from current directory" + @rm -rf logs + @find . -name "*.pyc" -delete + @find . 
-type d -name "__pycache__" -exec rm -rf {} + + @find . -name "*.sw[a-z]" -delete + @find . -type d -name "*.egg-info" -exec rm -rf {} + + +virtualenv: ## Create Python virtualenv + @test -d $(VIRTUALENV_PATH) && \ + (echo "The virtualenv $(VIRTUALENV_PATH) already exists. Skipping.") || \ + (echo "Creating the virtualenv $(VIRTUALENV_PATH) using $(SYSTEM_PYTHON)" & \ + $(SYSTEM_PYTHON) -m venv $(VIRTUALENV_PATH)) + +install: virtualenv ## Install python dependencies in existing virtualenv + @echo "Installing Python dependencies using $(PIP)" + @$(PIP) install --upgrade pip + @$(PIP) install nox + @$(PIP) install pre-commit + @cd .. && $(PIP) install -e .[all,tests,doc] + +config: ## Create sample configuration files related to Snowflake, Amazon and Google + @cd .. && test -e .env && \ + (echo "The file .env already exist. Skipping.") || \ + (echo "Creating .env..." && \ + cat .env-template > .env && \ + echo "Please, update .env with your credentials") + @cd .. && test -e test-connections.yaml && \ + (echo "The file test-connections.yaml already exist. Skipping.") || \ + (echo "Creating test-connections.yaml..." && \ + cat .github/ci-test-connections.yaml > test-connections.yaml && \ + echo "Please, update test-connections.yaml with your credentials") + +setup: config virtualenv install ## Setup a local development environment + +quality: + @$(PRECOMMIT) run --all-files + +test: virtualenv config ## Run all tests (use option: db=[db] run only run database-specific ones) +ifdef db + @cd .. && $(PYTEST) -s --cov --cov-branch --cov-report=term-missing -m "$(db)" +else + @cd .. && $(PYTEST) -s --cov --cov-branch --cov-report=term-missing +endif + +unit: virtualenv config ## Run unit tests + @cd .. && $(PYTEST) -s --cov --cov-branch --cov-report=term-missing -m "not integration" + +integration: virtualenv config ## Run integration tests + @cd .. && $(PYTEST) -s --cov --cov-branch --cov-report=term-missing -m integration + +help: + @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..7a4d902 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,129 @@ +[build-system] +requires = ["flit_core ~=3.2"] +build-backend = "flit_core.buildapi" + +[project] +name = "universal-transfer-operator" +dynamic = ["version"] +description = """Universal Transfer Operator transfers all the data that could be read from the source Dataset into +the destination Dataset. 
From a DAG author standpoint, all transfers would be performed through the invocation of +only the Universal Transfer Operator.""" + +authors = [ + { name = "Astronomer", email = "humans@astronomer.io" }, +] +readme = "README.md" +license = { file = "LICENSE" } + +requires-python = ">=3.7" +dependencies = [ + "apache-airflow>=2.0", + "attrs>=20.0", + "pandas>=1.3.4,<2.0.0", # Pinning it to <2.0.0 to avoid breaking changes + "pyarrow", + "python-frontmatter", + "smart-open", + "SQLAlchemy>=1.3.18", + "apache-airflow-providers-common-sql" +] + +keywords = ["airflow", "provider", "astronomer", "sql", "decorator", "task flow", "elt", "etl", "dag"] +classifiers = [ + "Development Status :: 5 - Production/Stable", + "License :: OSI Approved :: Apache Software License", + "Topic :: Database", + "Framework :: Apache Airflow", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", +] + +[project.optional-dependencies] +tests = [ + "pytest>=6.0", + "sqlalchemy-stubs", # Change when sqlalchemy is upgraded https://docs.sqlalchemy.org/en/14/orm/extensions/mypy.html +] +google = [ + "protobuf<=3.20", # Google bigquery client require protobuf <= 3.20.0. We can remove the limitation when this limitation is removed + "apache-airflow-providers-google>=6.4.0", + "sqlalchemy-bigquery>=1.3.0", + "smart-open[gcs]>=5.2.1" +] +snowflake = [ + "apache-airflow-providers-snowflake", + "snowflake-sqlalchemy>=1.2.0", + "snowflake-connector-python[pandas]", +] + +amazon = [ + "apache-airflow-providers-amazon>=5.0.0", + "s3fs", + "smart-open[s3]>=5.2.1", +] + +openlineage = ["openlineage-airflow>=0.17.0"] + +fivetran = ["airflow-provider-fivetran>=1.1.3"] + +all = [ + "apache-airflow-providers-amazon", + "apache-airflow-providers-google>=6.4.0", + "apache-airflow-providers-snowflake", + "smart-open[all]>=5.2.1", + "snowflake-connector-python[pandas]", + "snowflake-sqlalchemy>=1.2.0", + "sqlalchemy-bigquery>=1.3.0", + "s3fs", + "protobuf<=3.20", # Google bigquery client require protobuf <= 3.20.0. 
We can remove the limitation when this limitation is removed + "openlineage-airflow>=0.17.0", + "airflow-provider-fivetran>=1.1.3", +] +doc = [ + "myst-parser>=0.17", + "sphinx>=4.4.0", + "sphinx-autoapi>=2.0.0", + "sphinx-rtd-theme" +] + +[project.urls] +Home = "https://astronomer.io/" +Source = "https://github.com/astronomer/astro-sdk/universal_transfer_operator" + +[project.entry-points.apache_airflow_provider] +provider_info = "universal_transfer_operator.__init__:get_provider_info" + +[tool.pytest.ini_options] +minversion = "6.0" +addopts = "--durations=30 --durations-min=1.0" +env_files = [".env"] +testpaths = ["tests", "tests_integration"] +markers = [ + "integration" +] + + +[tool.flit.module] +name = "universal_transfer_operator" + +[tool.mypy] +color_output = true +#disallow_any_generics = true +#disallow_incomplete_defs = true +#disallow_untyped_defs = true +files = ["src/universal_transfer_operator"] +follow_imports = "skip" +no_implicit_optional = true +pretty = true +strict_equality = true +show_error_codes = true +show_error_context = true +warn_redundant_casts = true +warn_return_any = true +warn_unused_configs = true +ignore_missing_imports = true + +[tool.black] +line-length = 110 +target-version = ['py37', 'py38', 'py39', 'py310'] diff --git a/src/universal_transfer_operator/__init__.py b/src/universal_transfer_operator/__init__.py new file mode 100644 index 0000000..2c1a793 --- /dev/null +++ b/src/universal_transfer_operator/__init__.py @@ -0,0 +1,20 @@ +"""An Operator that allows transfers between different datasets.""" + +__version__ = "0.0.1dev1" + + +# This is needed to allow Airflow to pick up specific metadata fields it needs +# for certain features. We recognize it's a bit unclean to define these in +# multiple places, but at this point it's the only workaround if you'd like +# your custom conn type to show up in the Airflow UI. +def get_provider_info() -> dict: + return { + # Required. + "package-name": "universal-transfer-operator", + "name": "Universal Transfer Operator", + "description": __doc__, + "versions": [__version__], + # Optional. 
+ "hook-class-names": [], + "extra-links": [], + } diff --git a/src/universal_transfer_operator/constants.py b/src/universal_transfer_operator/constants.py new file mode 100644 index 0000000..ed9c909 --- /dev/null +++ b/src/universal_transfer_operator/constants.py @@ -0,0 +1,99 @@ +import sys +from enum import Enum + +# typing.Literal was only introduced in Python 3.8, and we support Python 3.7 +if sys.version_info >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal + + +class Location(Enum): + LOCAL = "local" + HTTP = "http" + HTTPS = "https" + GS = "gs" # Google Cloud Storage + GOOGLE_DRIVE = "gdrive" + S3 = "s3" # Amazon S3 + WASB = "wasb" # Azure Blob Storage + WASBS = "wasbs" # Azure Blob Storage + POSTGRES = "postgres" + POSTGRESQL = "postgres" + SQLITE = "sqlite" + DELTA = "delta" + BIGQUERY = "bigquery" + SNOWFLAKE = "snowflake" + REDSHIFT = "redshift" + + def __repr__(self): + return f"{self}" + + +class FileLocation(Enum): + # [START filelocation] + LOCAL = "local" + HTTP = "http" + HTTPS = "https" + GS = "gs" # Google Cloud Storage + GOOGLE_DRIVE = "gdrive" + S3 = "s3" # Amazon S3 + WASB = "wasb" # Azure Blob Storage + WASBS = "wasbs" # Azure Blob Storage + # [END filelocation] + + def __repr__(self): + return f"{self}" + + +class IngestorSupported(Enum): + # [START transferingestor] + Fivetran = "fivetran" + # [END transferingestor] + + def __repr__(self): + return f"{self}" + + +class TransferMode(Enum): + # [START TransferMode] + NATIVE = "native" + NONNATIVE = "nonnative" + THIRDPARTY = "thirdparty" + # [END TransferMode] + + def __str__(self) -> str: + return self.value + + +class FileType(Enum): + # [START filetypes] + CSV = "csv" + JSON = "json" + NDJSON = "ndjson" + PARQUET = "parquet" + # [END filetypes] + + def __repr__(self): + return f"{self}" + + +class Database(Enum): + # [START database] + POSTGRES = "postgres" + POSTGRESQL = "postgres" + SQLITE = "sqlite" + DELTA = "delta" + BIGQUERY = "bigquery" + SNOWFLAKE = "snowflake" + REDSHIFT = "redshift" + # [END database] + + def __repr__(self): + return f"{self}" + + +SUPPORTED_FILE_LOCATIONS = [const.value for const in FileLocation] +SUPPORTED_FILE_TYPES = [const.value for const in FileType] +SUPPORTED_DATABASES = [const.value for const in Database] + +LoadExistStrategy = Literal["replace", "append"] diff --git a/src/universal_transfer_operator/data_providers/__init__.py b/src/universal_transfer_operator/data_providers/__init__.py new file mode 100644 index 0000000..9c50929 --- /dev/null +++ b/src/universal_transfer_operator/data_providers/__init__.py @@ -0,0 +1,31 @@ +import importlib + +from airflow.hooks.base import BaseHook + +from universal_transfer_operator.constants import TransferMode +from universal_transfer_operator.data_providers.base import DataProviders +from universal_transfer_operator.datasets.base import Dataset +from universal_transfer_operator.utils import TransferParameters, get_class_name + +DATASET_CONN_ID_TO_DATAPROVIDER_MAPPING = dict.fromkeys( + ["s3", "aws"], "universal_transfer_operator.data_providers.filesystem.aws.s3" +) | dict.fromkeys( + ["gs", "google_cloud_platform"], "universal_transfer_operator.data_providers.filesystem.google.cloud.gcs" +) + + +def create_dataprovider( + dataset: Dataset, + transfer_params: TransferParameters = None, + transfer_mode: TransferMode = TransferMode.NONNATIVE, +) -> DataProviders: + conn_type = BaseHook.get_connection(dataset.conn_id).conn_type + module_path = DATASET_CONN_ID_TO_DATAPROVIDER_MAPPING[conn_type] + module = 
importlib.import_module(module_path) + class_name = get_class_name(module_ref=module, suffix="DataProvider") + data_provider: DataProviders = getattr(module, class_name)( + dataset=dataset, + transfer_params=transfer_params, + transfer_mode=transfer_mode, + ) + return data_provider diff --git a/src/universal_transfer_operator/data_providers/base.py b/src/universal_transfer_operator/data_providers/base.py new file mode 100644 index 0000000..ad03e90 --- /dev/null +++ b/src/universal_transfer_operator/data_providers/base.py @@ -0,0 +1,101 @@ +from __future__ import annotations + +from abc import ABC +from contextlib import contextmanager + +import attr +from airflow.hooks.base import BaseHook + +from universal_transfer_operator.constants import Location +from universal_transfer_operator.datasets.base import Dataset +from universal_transfer_operator.utils import TransferParameters, get_dataset_connection_type + + +class DataProviders(ABC): + """ + Base class to represent all the DataProviders interactions with Dataset. + + The goal is to be able to support new dataset by adding + a new module to the `uto/data_providers` directory, without the need of + changing other modules and classes. + """ + + def __init__( + self, + dataset: Dataset, + transfer_mode, + transfer_params: TransferParameters = attr.field( + factory=TransferParameters, + converter=lambda val: TransferParameters(**val) if isinstance(val, dict) else val, + ), + ): + self.dataset = dataset + self.transfer_params = transfer_params + self.transfer_mode = transfer_mode + self.transfer_mapping: set[Location] = set() + self.LOAD_DATA_NATIVELY_FROM_SOURCE: dict = {} + + def __repr__(self): + return f'{self.__class__.__name__}(conn_id="{self.dataset.conn_id})' + + @property + def hook(self) -> BaseHook: + """Return an instance of the Airflow hook.""" + raise NotImplementedError + + def check_if_exists(self) -> bool: + """Return true if the dataset exists""" + raise NotImplementedError + + def check_if_transfer_supported(self, source_dataset: Dataset) -> bool: + """ + Checks if the transfer is supported from source to destination based on source_dataset. + """ + source_connection_type = get_dataset_connection_type(source_dataset) + return Location(source_connection_type) in self.transfer_mapping + + @contextmanager + def read(self): + """Read the dataset and write to local reference location""" + raise NotImplementedError + + def write(self, source_ref): + """Write the data from local reference location to the dataset""" + raise NotImplementedError + + def load_data_from_source_natively(self, source_dataset: Dataset, destination_dataset: Dataset) -> None: + """ + Loads data from source dataset to the destination using data provider + """ + if not self.check_if_transfer_supported(source_dataset=source_dataset): + raise ValueError("Transfer not supported yet.") + + source_connection_type = get_dataset_connection_type(source_dataset) + destination_connection_type = get_dataset_connection_type(destination_dataset) + method_name = self.LOAD_DATA_FROM_SOURCE.get(source_connection_type) + if method_name: + transfer_method = self.__getattribute__(method_name) + return transfer_method( + source_dataset=source_dataset, + destination_dataset=destination_dataset, + ) + else: + raise ValueError( + f"No transfer performed from {source_connection_type} to {destination_connection_type}." 
+ ) + + @property + def openlineage_dataset_namespace(self) -> str: + """ + Returns the open lineage dataset namespace as per + https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md + """ + raise NotImplementedError + + @property + def openlineage_dataset_name(self) -> str: + """ + Returns the open lineage dataset name as per + https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md + """ + raise NotImplementedError diff --git a/src/universal_transfer_operator/data_providers/filesystem/__init__.py b/src/universal_transfer_operator/data_providers/filesystem/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/universal_transfer_operator/data_providers/filesystem/aws/__init__.py b/src/universal_transfer_operator/data_providers/filesystem/aws/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/universal_transfer_operator/data_providers/filesystem/aws/s3.py b/src/universal_transfer_operator/data_providers/filesystem/aws/s3.py new file mode 100644 index 0000000..74bfa0b --- /dev/null +++ b/src/universal_transfer_operator/data_providers/filesystem/aws/s3.py @@ -0,0 +1,160 @@ +from __future__ import annotations + +import os +from functools import cached_property +from tempfile import NamedTemporaryFile + +import attr +from airflow.providers.amazon.aws.hooks.s3 import S3Hook + +from universal_transfer_operator.constants import Location, TransferMode +from universal_transfer_operator.data_providers.filesystem.base import ( + BaseFilesystemProviders, + Path, + TempFile, + contextmanager, +) +from universal_transfer_operator.datasets.base import Dataset +from universal_transfer_operator.utils import TransferParameters + + +class S3DataProvider(BaseFilesystemProviders): + """ + DataProviders interactions with S3 Dataset. 
+ """ + + def __init__( + self, + dataset: Dataset, + transfer_params: TransferParameters = attr.field( + factory=TransferParameters, + converter=lambda val: TransferParameters(**val) if isinstance(val, dict) else val, + ), + transfer_mode: TransferMode = TransferMode.NONNATIVE, + ): + super().__init__( + dataset=dataset, + transfer_params=transfer_params, + transfer_mode=transfer_mode, + ) + self.transfer_mapping: set = { + Location.S3, + Location.GS, + } + + @cached_property + def hook(self) -> S3Hook: + """Return an instance of the database-specific Airflow hook.""" + return S3Hook( + aws_conn_id=self.dataset.conn_id, + verify=self.verify, + transfer_config_args=self.transfer_config_args, + extra_args=self.s3_extra_args, + ) + + def check_if_exists(self) -> bool: + """Return true if the dataset exists""" + return self.hook.check_for_key(key=self.dataset.path) + + @contextmanager + def read(self) -> list[TempFile]: + """Read the file from dataset and write to local file location""" + if not self.check_if_exists(): + raise ValueError(f"{self.dataset.path} doesn't exits") + files = self.hook.list_keys( + bucket_name=self.bucket_name, + prefix=self.s3_key, + delimiter=self.delimiter, + ) + local_file_paths = [] + try: + for file in files: + local_file_paths.append(self.download_file(file)) + yield local_file_paths + finally: + # Clean up the local files + self.cleanup(local_file_paths) + + def write(self, source_ref: list[TempFile]): + """Write the file from local file location to the dataset""" + + dest_s3_key = self.dataset.path + + if not self.keep_directory_structure and self.prefix: + dest_s3_key = os.path.join(dest_s3_key, self.prefix) + + destination_keys = [] + for file in source_ref: + if file.tmp_file.exists(): + dest_key = os.path.join(dest_s3_key, os.path.basename(file.actual_filename.name)) + self.hook.load_file( + filename=file.tmp_file.as_posix(), + key=dest_key, + replace="replace", + acl_policy=self.s3_acl_policy, + ) + destination_keys.append(dest_key) + + return destination_keys + + def download_file(self, file) -> TempFile: + """Download file and save to temporary path.""" + file_object = self.hook.get_key(file, self.bucket_name) + _, _, file_name = file.rpartition("/") + with NamedTemporaryFile(suffix=file_name, delete=False) as tmp_file: + file_object.download_fileobj(tmp_file) + return TempFile(tmp_file=Path(tmp_file.name), actual_filename=Path(file_name)) + + @property + def verify(self) -> str | bool | None: + return self.dataset.extra.get("verify", None) + + @property + def transfer_config_args(self) -> dict | None: + return self.dataset.extra.get("transfer_config_args", None) + + @property + def s3_extra_args(self) -> dict | None: + return self.dataset.extra.get("s3_extra_args", {}) + + @property + def bucket_name(self) -> str: + bucket_name, _ = self.hook.parse_s3_url(self.dataset.path) + return bucket_name + + @property + def s3_key(self) -> str: + _, key = self.hook.parse_s3_url(self.dataset.path) + return key + + @property + def s3_acl_policy(self) -> str | None: + return self.dataset.extra.get("s3_acl_policy", None) + + @property + def prefix(self) -> str | None: + return self.dataset.extra.get("prefix", None) + + @property + def keep_directory_structure(self) -> bool: + return self.dataset.extra.get("keep_directory_structure", False) + + @property + def delimiter(self) -> str | None: + return self.dataset.extra.get("delimiter", None) + + @property + def openlineage_dataset_namespace(self) -> str: + """ + Returns the open lineage dataset namespace as per + 
https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md + """ + raise NotImplementedError + + @property + def openlineage_dataset_name(self) -> str: + """ + Returns the open lineage dataset name as per + https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md + """ + raise NotImplementedError diff --git a/src/universal_transfer_operator/data_providers/filesystem/base.py b/src/universal_transfer_operator/data_providers/filesystem/base.py new file mode 100644 index 0000000..29c4ebe --- /dev/null +++ b/src/universal_transfer_operator/data_providers/filesystem/base.py @@ -0,0 +1,110 @@ +from __future__ import annotations + +import os +from pathlib import Path + +import attr +from airflow.hooks.dbapi import DbApiHook + +from universal_transfer_operator.constants import Location +from universal_transfer_operator.data_providers.base import DataProviders, contextmanager +from universal_transfer_operator.datasets.base import Dataset +from universal_transfer_operator.universal_transfer_operator import TransferParameters +from universal_transfer_operator.utils import get_dataset_connection_type + + +@attr.define +class TempFile: + tmp_file: Path | None + actual_filename: Path + + +class BaseFilesystemProviders(DataProviders): + """BaseFilesystemProviders represent all the DataProviders interactions with File system.""" + + def __init__( + self, + dataset: Dataset, + transfer_mode, + transfer_params: TransferParameters = attr.field( + factory=TransferParameters, + converter=lambda val: TransferParameters(**val) if isinstance(val, dict) else val, + ), + ): + self.dataset = dataset + self.transfer_params = transfer_params + self.transfer_mode = transfer_mode + self.transfer_mapping = {} + self.LOAD_DATA_NATIVELY_FROM_SOURCE: dict = {} + super().__init__( + dataset=self.dataset, transfer_mode=self.transfer_mode, transfer_params=self.transfer_params + ) + + def __repr__(self): + return f'{self.__class__.__name__}(conn_id="{self.dataset.conn_id})' + + @property + def hook(self) -> DbApiHook: + """Return an instance of the database-specific Airflow hook.""" + raise NotImplementedError + + def check_if_exists(self) -> bool: + """Return true if the dataset exists""" + raise NotImplementedError + + def check_if_transfer_supported(self, source_dataset: Dataset) -> bool: + """ + Checks if the transfer is supported from source to destination based on source_dataset. 
+ """ + source_connection_type = get_dataset_connection_type(source_dataset) + return Location(source_connection_type) in self.transfer_mapping + + @contextmanager + def read(self) -> list[TempFile]: + """Read the file dataset and write to local file location""" + raise NotImplementedError + + def write(self, source_ref) -> list[TempFile]: + """Write the source data from local file location to the dataset""" + raise NotImplementedError + + @staticmethod + def cleanup(file_list: list[TempFile]) -> None: + """Cleans up the temporary files created""" + for file in file_list: + if os.path.exists(file.tmp_file.name): + os.remove(file.tmp_file.name) + + def load_data_from_source_natively(self, source_dataset: Dataset, destination_dataset: Dataset) -> None: + """ + Loads data from source dataset to the destination using data provider + """ + if not self.check_if_transfer_supported(source_dataset=source_dataset): + raise ValueError("Transfer not supported yet.") + + source_connection_type = get_dataset_connection_type(source_dataset) + method_name = self.LOAD_DATA_NATIVELY_FROM_SOURCE.get(source_connection_type) + if method_name: + transfer_method = self.__getattribute__(method_name) + return transfer_method( + source_dataset=source_dataset, + destination_dataset=destination_dataset, + ) + else: + raise ValueError(f"No transfer performed from {source_connection_type} to S3.") + + @property + def openlineage_dataset_namespace(self) -> str: + """ + Returns the open lineage dataset namespace as per + https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md + """ + raise NotImplementedError + + @property + def openlineage_dataset_name(self) -> str: + """ + Returns the open lineage dataset name as per + https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md + """ + raise NotImplementedError diff --git a/src/universal_transfer_operator/data_providers/filesystem/google/__init__.py b/src/universal_transfer_operator/data_providers/filesystem/google/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/universal_transfer_operator/data_providers/filesystem/google/cloud/__init__.py b/src/universal_transfer_operator/data_providers/filesystem/google/cloud/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/universal_transfer_operator/data_providers/filesystem/google/cloud/gcs.py b/src/universal_transfer_operator/data_providers/filesystem/google/cloud/gcs.py new file mode 100644 index 0000000..c5d81c1 --- /dev/null +++ b/src/universal_transfer_operator/data_providers/filesystem/google/cloud/gcs.py @@ -0,0 +1,166 @@ +from __future__ import annotations + +import logging +from functools import cached_property +from tempfile import NamedTemporaryFile +from typing import Sequence + +import attr +from airflow.providers.google.cloud.hooks.gcs import GCSHook, _parse_gcs_url + +from universal_transfer_operator.constants import Location, TransferMode +from universal_transfer_operator.data_providers.filesystem.base import ( + BaseFilesystemProviders, + Path, + TempFile, + contextmanager, +) +from universal_transfer_operator.datasets.base import Dataset +from universal_transfer_operator.utils import TransferParameters + + +class GCSDataProvider(BaseFilesystemProviders): + """ + DataProviders interactions with GS Dataset. 
+ """ + + def __init__( + self, + dataset: Dataset, + transfer_params: TransferParameters = attr.field( + factory=TransferParameters, + converter=lambda val: TransferParameters(**val) if isinstance(val, dict) else val, + ), + transfer_mode: TransferMode = TransferMode.NONNATIVE, + ): + super().__init__( + dataset=dataset, + transfer_params=transfer_params, + transfer_mode=transfer_mode, + ) + self.transfer_mapping = { + Location.S3, + Location.GS, + } + + @cached_property + def hook(self) -> GCSHook: + """Return an instance of the database-specific Airflow hook.""" + return GCSHook( + gcp_conn_id=self.dataset.conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.google_impersonation_chain, + ) + + def check_if_exists(self) -> bool: + """Return true if the dataset exists""" + return self.hook.exists(bucket_name=self.bucket_name, object_name=self.blob_name) + + @contextmanager + def read(self) -> list[TempFile]: + """Read the file from dataset and write to local file location""" + if not self.check_if_exists(): + raise ValueError(f"{self.dataset.path} doesn't exits") + + logging.info( + "Getting list of the files. Bucket: %s; Delimiter: %s; Prefix: %s", + self.bucket_name, # type: ignore + self.delimiter, + self.prefix, + ) + files = self.hook.list( + bucket_name=self.bucket_name, # type: ignore + prefix=self.prefix, + delimiter=self.delimiter, + ) + + try: + local_file_paths = [] + if files: + for file in files: + local_file_paths.append(self.download_file(file)) + yield local_file_paths + finally: + # Clean up the local files + self.cleanup(local_file_paths) + + def write(self, source_ref: list[TempFile]) -> list[str]: + """Write the file from local file location to the dataset""" + destination_objects = [] + if source_ref: + for file in source_ref: + destination_objects.append(self.upload_file(file)) + logging.info("All done, uploaded %d files to Google Cloud Storage", len(source_ref)) + else: + logging.info("In sync, no files needed to be uploaded to Google Cloud Storage") + return destination_objects + + def upload_file(self, file: TempFile): + """Upload file to GCS and return path""" + # There will always be a '/' before file because it is + # enforced at instantiation time + dest_gcs_object = self.blob_name + file.actual_filename.name + self.hook.upload( + bucket_name=self.bucket_name, + object_name=dest_gcs_object, + filename=file.tmp_file.as_posix(), + gzip=self.gzip, + ) + return dest_gcs_object + + def download_file(self, file) -> TempFile: + """Download file and save to temporary path.""" + _, _, file_name = file.rpartition("/") + with NamedTemporaryFile(suffix=file_name, delete=False) as tmp_file: + self.hook.download( + bucket_name=self.bucket_name, + object_name=file, + filename=tmp_file.name, + ) + return TempFile(tmp_file=Path(tmp_file.name), actual_filename=Path(file_name)) + + @property + def delegate_to(self) -> str | None: + return self.dataset.extra.get("delegate_to", None) + + @property + def google_impersonation_chain(self) -> str | Sequence[str] | None: + return self.dataset.extra.get("google_impersonation_chain", None) + + @property + def delimiter(self) -> str | None: + return self.dataset.extra.get("delimiter", None) + + @property + def bucket_name(self) -> str: + bucket_name, _ = _parse_gcs_url(gsurl=self.dataset.path) + return bucket_name + + @property + def prefix(self) -> str | None: + return self.dataset.extra.get("prefix", None) + + @property + def gzip(self) -> bool | None: + return self.dataset.extra.get("gzip", False) + + @property + def 
blob_name(self) -> str: + _, blob = _parse_gcs_url(gsurl=self.dataset.path) + return blob + + @property + def openlineage_dataset_namespace(self) -> str: + """ + Returns the open lineage dataset namespace as per + https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md + """ + raise NotImplementedError + + @property + def openlineage_dataset_name(self) -> str: + """ + Returns the open lineage dataset name as per + https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md + """ + raise NotImplementedError diff --git a/src/universal_transfer_operator/datasets/__init__.py b/src/universal_transfer_operator/datasets/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/universal_transfer_operator/datasets/apis.py b/src/universal_transfer_operator/datasets/apis.py new file mode 100644 index 0000000..6aecd67 --- /dev/null +++ b/src/universal_transfer_operator/datasets/apis.py @@ -0,0 +1,17 @@ +from __future__ import annotations + +from attr import define + +from universal_transfer_operator.datasets.base import Dataset + + +@define +class API(Dataset): + """ + Repersents all API dataset. + Intended to be used within library. + + :param path: Path to a file in the filesystem/Object stores + """ + + # TODO: define the name and namespace for API diff --git a/src/universal_transfer_operator/datasets/base.py b/src/universal_transfer_operator/datasets/base.py new file mode 100644 index 0000000..d757db7 --- /dev/null +++ b/src/universal_transfer_operator/datasets/base.py @@ -0,0 +1,11 @@ +from __future__ import annotations + +try: + # Airflow >= 2.4 + from airflow.datasets import Dataset + + DATASET_SUPPORT = True +except ImportError: + # Airflow < 2.4 + Dataset = object + DATASET_SUPPORT = False diff --git a/src/universal_transfer_operator/datasets/dataframe.py b/src/universal_transfer_operator/datasets/dataframe.py new file mode 100644 index 0000000..a15fc08 --- /dev/null +++ b/src/universal_transfer_operator/datasets/dataframe.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +from attr import define + +from universal_transfer_operator.datasets.base import Dataset + + +@define +class Dataframe(Dataset): + """ + Repersents all dataframe dataset. + Intended to be used within library. + + :param path: Path to a file in the filesystem/Object stores + :param conn_id: Airflow connection ID + :param name: name of dataframe + """ + + name: str | None = None + + # TODO: define the name and namespace for dataframe diff --git a/src/universal_transfer_operator/datasets/file.py b/src/universal_transfer_operator/datasets/file.py new file mode 100644 index 0000000..fea6e74 --- /dev/null +++ b/src/universal_transfer_operator/datasets/file.py @@ -0,0 +1,116 @@ +from __future__ import annotations + +import io +import pathlib + +import pandas as pd +import smart_open +from attr import define, field + +from universal_transfer_operator.constants import FileType +from universal_transfer_operator.datasets.base import Dataset + + +@define +class File(Dataset): + """ + Repersents all file dataset. + + :param path: Path to a file in the filesystem/Object stores + :param conn_id: Airflow connection ID + :param filetype: constant to provide an explicit file type + :param normalize_config: parameters in dict format of pandas json_normalize() function. 
+ :param is_bytes: is bytes + """ + + path: str = field(default="") + conn_id: str = field(default="") + filetype: FileType | None = None + normalize_config: dict | None = None + is_bytes: bool = False + uri: str = field(init=False) + extra: dict = field(init=True, factory=dict) + + @property + def size(self) -> int: + """ + Return the size in bytes of the given file. + + :return: File size in bytes + """ + size: int = self.location.size + return size + + def is_binary(self) -> bool: + """ + Return a constants.FileType given the filepath. Uses a naive strategy, using the file extension. + + :return: True or False + """ + result: bool = self.type.name == FileType.PARQUET + return result + + def is_pattern(self) -> bool: + """ + Returns True when file path is a pattern(eg. s3://bucket/folder or /folder/sample_* etc) + + :return: True or False + """ + return not pathlib.PosixPath(self.path).suffix + + def create_from_dataframe(self, df: pd.DataFrame) -> None: + """Create a file in the desired location using the values of a dataframe. + + :param df: pandas dataframe + """ + with smart_open.open(self.path, mode="wb", transport_params=self.location.transport_params) as stream: + self.type.create_from_dataframe(stream=stream, df=df) + + def export_to_dataframe(self, **kwargs) -> pd.DataFrame: + """Read file from all supported location and convert them into dataframes.""" + mode = "rb" if self.is_binary() else "r" + with smart_open.open(self.path, mode=mode, transport_params=self.location.transport_params) as stream: + return self.type.export_to_dataframe(stream, **kwargs) + + def _convert_remote_file_to_byte_stream(self) -> io.IOBase: + """ + Read file from all supported location and convert them into a buffer that can be streamed into other data + structures. + Due to noted issues with using smart_open with pandas (like + https://github.com/RaRe-Technologies/smart_open/issues/524), we create a BytesIO or StringIO buffer + before exporting to a dataframe. We've found a sizable speed improvement with this optimization + + :returns: an io object that can be streamed into a dataframe (or other object) + """ + + mode = "rb" if self.is_binary() else "r" + remote_obj_buffer = io.BytesIO() if self.is_binary() else io.StringIO() + with smart_open.open(self.path, mode=mode, transport_params=self.location.transport_params) as stream: + remote_obj_buffer.write(stream.read()) + remote_obj_buffer.seek(0) + return remote_obj_buffer + + def export_to_dataframe_via_byte_stream(self, **kwargs) -> pd.DataFrame: + """Read files from all supported locations and convert them into dataframes. + Due to noted issues with using smart_open with pandas (like + https://github.com/RaRe-Technologies/smart_open/issues/524), we create a BytesIO or StringIO buffer + before exporting to a dataframe. We've found a sizable speed improvement with this optimization. 
+ """ + + return self.type.export_to_dataframe(self._convert_remote_file_to_byte_stream(), **kwargs) + + def exists(self) -> bool: + """Check if the file exists or not""" + file_exists: bool = self.location.exists() + return file_exists + + def __str__(self) -> str: + return self.path + + def __eq__(self, other) -> bool: + if not isinstance(other, self.__class__): + return NotImplemented + return self.location == other.location and self.type == other.type + + def __hash__(self) -> int: + return hash((self.path, self.conn_id, self.filetype)) diff --git a/src/universal_transfer_operator/datasets/table.py b/src/universal_transfer_operator/datasets/table.py new file mode 100644 index 0000000..68aedb8 --- /dev/null +++ b/src/universal_transfer_operator/datasets/table.py @@ -0,0 +1,91 @@ +from __future__ import annotations + +from urllib.parse import urlparse + +from attr import define, field, fields_dict +from sqlalchemy import Column + +from universal_transfer_operator.datasets.base import Dataset + + +@define +class Metadata: + """ + Contains additional information to access a SQL Table, which is very likely optional and, in some cases, may + be database-specific. + + :param schema: A schema name + :param database: A database name + """ + + # This property is used by several databases, including: Postgres, Snowflake and BigQuery ("namespace") + schema: str | None = None + database: str | None = None + + def is_empty(self) -> bool: + """Check if all the fields are None.""" + return all(getattr(self, field_name) is None for field_name in fields_dict(self.__class__)) + + +@define +class Table(Dataset): + """ + Repersents all Table datasets. + Intended to be used within library. + + :param name: The name of the database table. If name not provided then it would create a temporary name + :param conn_id: The Airflow connection id. This will be used to identify the right database type at the runtime + :param metadata: A metadata object which will have database or schema name + :param columns: columns which define the database table schema. 
+ """ + + name: str = field(default="") + conn_id: str = field(default="") + # Setting converter allows passing a dictionary to metadata arg + metadata: Metadata = field( + factory=Metadata, + converter=lambda val: Metadata(**val) if isinstance(val, dict) else val, + ) + columns: list[Column] = field(factory=list) + uri: str = field(init=False) + extra: dict = field(init=True, factory=dict) + + @property + def sql_type(self): + raise NotImplementedError + + def exists(self): + """Check if the table exists or not""" + raise NotImplementedError + + def __str__(self) -> str: + return self.path + + def __hash__(self) -> int: + return hash((self.path, self.conn_id)) + + def dataset_scheme(self): + """ + Return the scheme based on path + """ + parsed = urlparse(self.path) + return parsed.scheme + + def dataset_namespace(self): + """ + The namespace of a dataset can be combined to form a URI (scheme:[//authority]path) + + Namespace = scheme:[//authority] (the dataset) + """ + parsed = urlparse(self.path) + namespace = f"{self.dataset_scheme()}://{parsed.netloc}" + return namespace + + def dataset_name(self): + """ + The name of a dataset can be combined to form a URI (scheme:[//authority]path) + + Name = path (the datasets) + """ + parsed = urlparse(self.path) + return parsed.path if self.path else self.name diff --git a/src/universal_transfer_operator/integrations/__init__.py b/src/universal_transfer_operator/integrations/__init__.py new file mode 100644 index 0000000..6335b98 --- /dev/null +++ b/src/universal_transfer_operator/integrations/__init__.py @@ -0,0 +1,30 @@ +import importlib + +from airflow.hooks.base import BaseHook + +from universal_transfer_operator.constants import IngestorSupported +from universal_transfer_operator.integrations.base import TransferIntegration, TransferIntegrationOptions +from universal_transfer_operator.utils import get_class_name + +CUSTOM_INGESTION_TYPE_TO_MODULE_PATH = {"fivetran": "universal_transfer_operator.integrations.fivetran"} + + +def get_transfer_integration(transfer_params: TransferIntegrationOptions) -> TransferIntegration: + """ + Given a transfer_params return the associated TransferIntegrations class. + + :param transfer_params: kwargs to be used by methods involved in transfer using FiveTran. 
+ """ + thirdparty_conn_id = transfer_params.conn_id + + if thirdparty_conn_id is None: + raise ValueError("Connection id for integration is not specified.") + thirdparty_conn_type = BaseHook.get_connection(thirdparty_conn_id).conn_type + if thirdparty_conn_type not in {item.value for item in IngestorSupported}: + raise ValueError("Ingestion platform not yet supported.") + + module_path = CUSTOM_INGESTION_TYPE_TO_MODULE_PATH[thirdparty_conn_type] + module = importlib.import_module(module_path) + class_name = get_class_name(module_ref=module, suffix="Integration") + transfer_integrations: TransferIntegration = getattr(module, class_name)(transfer_params) + return transfer_integrations diff --git a/src/universal_transfer_operator/integrations/base.py b/src/universal_transfer_operator/integrations/base.py new file mode 100644 index 0000000..89d9a80 --- /dev/null +++ b/src/universal_transfer_operator/integrations/base.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod + +import attr +from airflow.hooks.base import BaseHook + +from universal_transfer_operator.datasets.base import Dataset +from universal_transfer_operator.utils import TransferParameters + + +@attr.define +class TransferIntegrationOptions(TransferParameters): + """TransferIntegrationOptions for transfer integration configuration""" + + conn_id: str = attr.field(default="") + + +class TransferIntegration(ABC): + """Basic implementation of a third party transfer.""" + + def __init__( + self, + transfer_params: TransferIntegrationOptions = attr.field( + factory=TransferIntegrationOptions, + converter=lambda val: TransferIntegrationOptions(**val) if isinstance(val, dict) else val, + ), + ): + self.transfer_params = transfer_params + # transfer mapping creates a mapping between various sources and destination, where + # transfer is possible using the integration + self.transfer_mapping: dict[str, str] = None + # TODO: add method for validation, transfer mapping, transfer params etc + + @property + def hook(self) -> BaseHook: + """Return an instance of the database-specific Airflow hook.""" + raise NotImplementedError + + @abstractmethod + def transfer_job(self, source_dataset: Dataset, destination_dataset: Dataset) -> None: + """ + Loads data from source dataset to the destination using ingestion config + """ + raise NotImplementedError diff --git a/src/universal_transfer_operator/integrations/fivetran.py b/src/universal_transfer_operator/integrations/fivetran.py new file mode 100644 index 0000000..bd830e2 --- /dev/null +++ b/src/universal_transfer_operator/integrations/fivetran.py @@ -0,0 +1,306 @@ +from __future__ import annotations + +import logging +from functools import cached_property + +import attr +from attr import field +from fivetran_provider.hooks.fivetran import FivetranHook + +from universal_transfer_operator.datasets.base import Dataset +from universal_transfer_operator.integrations.base import TransferIntegration, TransferIntegrationOptions + + +@attr.define +class Group: + """ + Fivetran group details + :param name: The name of the group within Fivetran account. + :param group_id : Group id in fivetran system + """ + + name: str + group_id: str | None = None + + +@attr.define +class Destination: + """ + Fivetran destination details + + :param destination_id: The unique identifier for the destination within the Fivetran system + :param service: services as per https://fivetran.com/docs/rest-api/destinations/config + :param region: Data processing location. 
diff --git a/src/universal_transfer_operator/integrations/fivetran.py b/src/universal_transfer_operator/integrations/fivetran.py
new file mode 100644
index 0000000..bd830e2
--- /dev/null
+++ b/src/universal_transfer_operator/integrations/fivetran.py
@@ -0,0 +1,306 @@
+from __future__ import annotations
+
+import logging
+from functools import cached_property
+
+import attr
+from attr import field
+from fivetran_provider.hooks.fivetran import FivetranHook
+
+from universal_transfer_operator.datasets.base import Dataset
+from universal_transfer_operator.integrations.base import TransferIntegration, TransferIntegrationOptions
+
+
+@attr.define
+class Group:
+    """
+    Fivetran group details.
+
+    :param name: The name of the group within the Fivetran account.
+    :param group_id: Group id in the Fivetran system.
+    """
+
+    name: str
+    group_id: str | None = None
+
+
+@attr.define
+class Destination:
+    """
+    Fivetran destination details.
+
+    :param destination_id: The unique identifier for the destination within the Fivetran system.
+    :param service: Service as per https://fivetran.com/docs/rest-api/destinations/config
+    :param region: Data processing location. This is where Fivetran will operate and run computation on data.
+    :param time_zone_offset: Determines the time zone for the Fivetran sync schedule.
+    :param run_setup_tests: Specifies whether setup tests should be run automatically.
+    :param config: Configuration as per the destination specified at https://fivetran.com/docs/rest-api/destinations/config
+    """
+
+    service: str
+    config: dict
+    destination_id: str | None = None
+    time_zone_offset: str | None = "-5"
+    region: str | None = "GCP_US_EAST4"
+    run_setup_tests: bool | None = True
+
+
+@attr.define
+class Connector:
+    """
+    Fivetran connector details.
+
+    :param connector_id: The unique identifier for the connector within the Fivetran system.
+    :param service: Service as per https://fivetran.com/docs/rest-api/destinations/config
+    :param config: Configuration as per the connector specified at https://fivetran.com/docs/rest-api/connectors/config
+    :param paused: Specifies whether the connector is paused. Defaults to True.
+    :param pause_after_trial: Specifies whether the connector should be paused after the free trial period has ended.
+        Defaults to True.
+    :param sync_frequency: The connector sync frequency in minutes. Enum: "5" "15" "30" "60" "120" "180" "360" "480"
+        "720" "1440". Defaults to "5".
+    :param daily_sync_time: Defines the sync start time when the sync frequency is already set, or is being set by the
+        current request, to 1440.
+    :param schedule_type: Defines the schedule type.
+    :param connect_card_config: Connector card configuration.
+    :param trust_certificates: Specifies whether we should trust the certificate automatically. The default value is
+        False.
+    :param trust_fingerprints: Specifies whether we should trust the SSH fingerprint automatically. The default value
+        is False.
+    :param run_setup_tests: Specifies whether the setup tests should be run automatically. The default value is True.
+    """
+
+    connector_id: str | None
+    service: str
+    config: dict
+    connect_card_config: dict
+    paused: bool = True
+    pause_after_trial: bool = True
+    sync_frequency: str = "5"
+    daily_sync_time: str = "00:00"
+    schedule_type: str = ""
+    trust_certificates: bool = False
+    trust_fingerprints: bool = False
+    run_setup_tests: bool = True
+
+
+@attr.define
+class FiveTranOptions(TransferIntegrationOptions):
+    conn_id: str | None = field(default="fivetran_default")
+    connector_id: str | None = field(default="")
+    retry_limit: int = 3
+    retry_delay: int = 1
+    poll_interval: int = 15
+    schedule_type: str = "manual"
+    connector: Connector | None = attr.field(default=None)
+    group: Group | None = attr.field(default=None)
+    destination: Destination | None = attr.field(default=None)
+
+
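A quick sketch of how these options are typically populated (the connection id and connector id below are placeholders). This is the fast path used by the working example DAG, where an existing, fully configured Fivetran connector is reused; alternatively, `group`, `destination` and `connector` configuration can be supplied so the integration creates them first:

```python
from universal_transfer_operator.integrations.fivetran import FiveTranOptions

# Reuse an existing Fivetran connector; connector_id "filing_muppet" is illustrative only.
options = FiveTranOptions(conn_id="fivetran_default", connector_id="filing_muppet")
```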
+ """ + + api_user_agent = "airflow_provider_fivetran/1.1.3" + api_protocol = "https" + api_host = "api.fivetran.com" + api_path_connectors = "v1/connectors/" + api_path_groups = "v1/groups/" + api_path_destinations = "v1/destinations/" + + def __init__( + self, + transfer_params: FiveTranOptions = attr.field( + factory=FiveTranOptions, + converter=lambda val: FiveTranOptions(**val) if isinstance(val, dict) else val, + ), + ): + self.transfer_params = transfer_params + self.transfer_mapping = None + super().__init__(transfer_params=self.transfer_params) + + @cached_property + def hook(self) -> FivetranHook: + """Return an instance of the database-specific Airflow hook.""" + return FivetranHook( + self.transfer_params.conn_id, + retry_limit=self.transfer_params.retry_limit, + retry_delay=self.transfer_params.retry_delay, + ) + + def transfer_job(self, source_dataset: Dataset, destination_dataset: Dataset) -> None: + """ + Loads data from source dataset to the destination using ingestion config + """ + fivetran_hook = self.hook + + # Check if connector_id is passed and check if it exists and do the transfer. + connector_id = self.transfer_params.connector_id + if self.check_for_connector_id(fivetran_hook=fivetran_hook): + fivetran_hook.prep_connector( + connector_id=connector_id, schedule_type=self.transfer_params.schedule_type + ) + # TODO: wait until the job is done + return fivetran_hook.start_fivetran_sync(connector_id=connector_id) + + group_id = self.transfer_params.group.group_id + if not self.check_group_details(fivetran_hook=fivetran_hook, group_id=group_id): + # create group if not group_id is not passed. + group_id = self.create_group(fivetran_hook=fivetran_hook) + + destination_id = self.transfer_params.destination.destination_id + if not self.check_destination_details(fivetran_hook=fivetran_hook, destination_id=destination_id): + # Check for destination based on destination_id else create destination + self.create_destination(fivetran_hook=fivetran_hook, group_id=group_id) + + # Create connector if it doesn't exist + connector_id = self.create_connector(fivetran_hook=fivetran_hook, group_id=group_id) + + # Run connector setup test + self.run_connector_setup_tests(fivetran_hook=fivetran_hook, connector_id=connector_id) + + # Sync connector data + fivetran_hook.prep_connector(connector_id=connector_id, schedule_type=self.schedule_type) + return fivetran_hook.start_fivetran_sync(connector_id=connector_id) + + def check_for_connector_id(self, fivetran_hook: FivetranHook) -> bool: + """ + Ensures connector configuration has been completed successfully and is in a functional state. + """ + connector_id = self.transfer_params.connector_id + if connector_id is None: + logging.warning("No value specified for connector_id") + return False + + return fivetran_hook.check_connector(connector_id=connector_id) + + def check_group_details(self, fivetran_hook: FivetranHook, group_id: str | None) -> bool: + """ + Check if group_id is exists. + """ + + if group_id is None: + logging.warning( + "group_id is None. It should be the unique identifier for " + "the group within the Fivetran system. 
" + ) + return False + endpoint = self.api_path_groups + group_id + api_response = fivetran_hook._do_api_call(("GET", endpoint)) # skipcq: PYL-W0212 + if api_response["code"] == "Success": + logging.info("group_id {group_id} found.", extra={"group_id": group_id}) + else: + raise ValueError(api_response) + return True + + def create_group(self, fivetran_hook: FivetranHook) -> str: + """ + Creates the group based on group name passed + """ + endpoint = self.api_path_groups + group_dict = self.transfer_params.group + if group_dict is None: + raise ValueError("Group is none. Pass a valid group") + group = Group(**group_dict) + payload = {"name": group.name} + api_response = fivetran_hook._do_api_call(("POST", endpoint), json=payload) # skipcq: PYL-W0212 + if api_response["code"] == "Success": + logging.info(api_response) + else: + raise ValueError(api_response) + return api_response["data"]["id"] + + def check_destination_details(self, fivetran_hook: FivetranHook, destination_id: str | None) -> bool: + """ + Check if destination_id is exists. + """ + if destination_id is None: + logging.warning( + "destination_id is None. It should be the unique identifier for " + "the destination within the Fivetran system. " + ) + return False + endpoint = self.api_path_destinations + destination_id + api_response = fivetran_hook._do_api_call(("GET", endpoint)) # skipcq: PYL-W0212 + if api_response["code"] == "Success": + logging.info("destination_id {destination_id} found.", extra={"destination_id": destination_id}) + else: + raise ValueError(api_response) + return True + + def create_destination(self, fivetran_hook: FivetranHook, group_id: str) -> dict: + """ + Creates the destination based on destination configuration passed + """ + endpoint = self.api_path_destinations + destination_dict = self.transfer_params.destination + if destination_dict is None: + raise ValueError("destination is none. Pass a valid destination") + destination = Destination(**destination_dict) + payload = { + "group_id": group_id, + "service": destination.service, + "region": destination.region, + "time_zone_offset": destination.time_zone_offset, + "config": destination.config, + "run_setup_tests": destination.run_setup_tests, + } + api_response = fivetran_hook._do_api_call(("POST", endpoint), json=payload) # skipcq: PYL-W0212 + if api_response["code"] == "Success": + logging.info(api_response) + # TODO: parse all setup tests status for passed status + else: + raise ValueError(api_response) + return api_response + + def create_connector(self, fivetran_hook: FivetranHook, group_id: str) -> str: + """ + Creates the connector based on connector configuration passed + """ + endpoint = self.api_path_connectors + connector_dict = self.transfer_params.connector + if connector_dict is None: + raise ValueError("connector is none. 
+
+    def create_connector(self, fivetran_hook: FivetranHook, group_id: str) -> str:
+        """
+        Create the connector based on the connector configuration passed.
+        """
+        endpoint = self.api_path_connectors
+        connector_dict = self.transfer_params.connector
+        if connector_dict is None:
+            raise ValueError("Connector is None. Pass a valid connector.")
+
+        connector = Connector(**connector_dict)
+        payload = {
+            "group_id": group_id,
+            "service": connector.service,
+            "trust_certificates": connector.trust_certificates,
+            "trust_fingerprints": connector.trust_fingerprints,
+            "run_setup_tests": connector.run_setup_tests,
+            "paused": connector.paused,
+            "pause_after_trial": connector.pause_after_trial,
+            "sync_frequency": connector.sync_frequency,
+            "daily_sync_time": connector.daily_sync_time,
+            "schedule_type": connector.schedule_type,
+            "connect_card_config": connector.connect_card_config,
+            "config": connector.config,
+        }
+        api_response = fivetran_hook._do_api_call(("POST", endpoint), json=payload)  # skipcq: PYL-W0212
+        if api_response["code"] == "Success":
+            logging.info(api_response)
+            # TODO: parse all setup tests status for passed status
+        else:
+            raise ValueError(api_response)
+        return api_response["data"]["id"]
+
+    def run_connector_setup_tests(self, fivetran_hook: FivetranHook, connector_id: str):
+        """
+        Run the setup tests for an existing connector within your Fivetran account.
+        """
+        endpoint = self.api_path_connectors + connector_id + "/test"
+        connector_dict = self.transfer_params.connector
+        if connector_dict is None:
+            raise ValueError("Connector is None. Pass a valid connector.")
+
+        connector = Connector(**connector_dict)
+        payload = {
+            "trust_certificates": connector.trust_certificates,
+            "trust_fingerprints": connector.trust_fingerprints,
+        }
+        api_response = fivetran_hook._do_api_call(("POST", endpoint), json=payload)  # skipcq: PYL-W0212
+        if api_response["code"] == "Success":
+            logging.info(api_response)
+            # TODO: parse all setup tests status for passed status
+        else:
+            raise ValueError(api_response)
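An illustrative DAG snippet that loosely mirrors the example DAG in this PR, showing how the operator defined below drives this integration in third-party mode alongside a non-native file transfer. Connection ids, bucket paths, the table name and the connector id are placeholders:

```python
from datetime import datetime

from airflow import DAG

from universal_transfer_operator.constants import TransferMode
from universal_transfer_operator.datasets.file import File
from universal_transfer_operator.datasets.table import Metadata, Table
from universal_transfer_operator.integrations.fivetran import FiveTranOptions
from universal_transfer_operator.universal_transfer_operator import UniversalTransferOperator

with DAG("example_universal_transfer_operator", start_date=datetime(2023, 1, 1), schedule=None):
    # Non-native transfer: stream files from S3 to GCS via the filesystem DataProviders.
    s3_to_gcs = UniversalTransferOperator(
        task_id="s3_to_gcs",
        source_dataset=File(path="s3://uto-demo/data/", conn_id="aws_default"),
        destination_dataset=File(path="gs://uto-demo/data/", conn_id="google_cloud_default"),
    )

    # Third-party transfer: delegate S3 -> Snowflake to an existing Fivetran connector.
    s3_to_snowflake = UniversalTransferOperator(
        task_id="s3_to_snowflake_via_fivetran",
        source_dataset=File(path="s3://uto-demo/data/", conn_id="aws_default"),
        destination_dataset=Table(
            name="FIVETRAN_TABLE", conn_id="snowflake_default", metadata=Metadata(schema="PUBLIC")
        ),
        transfer_mode=TransferMode.THIRDPARTY,
        transfer_params=FiveTranOptions(conn_id="fivetran_default", connector_id="filing_muppet"),
    )
```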
diff --git a/src/universal_transfer_operator/universal_transfer_operator.py b/src/universal_transfer_operator/universal_transfer_operator.py
new file mode 100644
index 0000000..6af5393
--- /dev/null
+++ b/src/universal_transfer_operator/universal_transfer_operator.py
@@ -0,0 +1,69 @@
+from __future__ import annotations
+
+from typing import Any
+
+import attr
+from airflow.models import BaseOperator
+from airflow.utils.context import Context
+
+from universal_transfer_operator.constants import TransferMode
+from universal_transfer_operator.data_providers import create_dataprovider
+from universal_transfer_operator.datasets.base import Dataset
+from universal_transfer_operator.integrations import get_transfer_integration
+from universal_transfer_operator.utils import TransferParameters
+
+
+class UniversalTransferOperator(BaseOperator):
+    """
+    Transfers all the data that can be read from the source Dataset into the destination Dataset. From a DAG author
+    standpoint, all transfers are performed through the invocation of only the Universal Transfer Operator.
+
+    :param source_dataset: Source dataset to be transferred.
+    :param destination_dataset: Destination dataset to be transferred to.
+    :param transfer_params: kwargs to be used by the methods involved in the transfer flow.
+    :param transfer_mode: The TransferMode to use: native, non-native or third-party.
+    :param if_exists: Overwrite the destination if it exists. Defaults to "replace".
+
+    :return: returns the destination dataset
+    """
+
+    def __init__(
+        self,
+        *,
+        source_dataset: Dataset,
+        destination_dataset: Dataset,
+        transfer_params: TransferParameters = attr.field(
+            factory=TransferParameters,
+            converter=lambda val: TransferParameters(**val) if isinstance(val, dict) else val,
+        ),
+        transfer_mode: TransferMode = TransferMode.NONNATIVE,
+        **kwargs,
+    ) -> None:
+        self.source_dataset = source_dataset
+        self.destination_dataset = destination_dataset
+        self.transfer_mode = transfer_mode
+        # TODO: revisit the naming of transfer_mode
+        self.transfer_params = transfer_params
+        super().__init__(**kwargs)
+
+    def execute(self, context: Context) -> Any:  # skipcq: PYL-W0613
+        if self.transfer_mode == TransferMode.THIRDPARTY:
+            transfer_integration = get_transfer_integration(self.transfer_params)
+            return transfer_integration.transfer_job(self.source_dataset, self.destination_dataset)
+
+        source_dataprovider = create_dataprovider(
+            dataset=self.source_dataset,
+            transfer_params=self.transfer_params,
+            transfer_mode=self.transfer_mode,
+        )
+
+        destination_dataprovider = create_dataprovider(
+            dataset=self.destination_dataset,
+            transfer_params=self.transfer_params,
+            transfer_mode=self.transfer_mode,
+        )
+
+        with source_dataprovider.read() as source_data:
+            destination_data = destination_dataprovider.write(source_data)
+
+        return destination_data
diff --git a/src/universal_transfer_operator/utils.py b/src/universal_transfer_operator/utils.py
new file mode 100644
index 0000000..70fdb86
--- /dev/null
+++ b/src/universal_transfer_operator/utils.py
@@ -0,0 +1,60 @@
+from typing import Any
+
+import attr
+from airflow.exceptions import AirflowNotFoundException
+from airflow.hooks.base import BaseHook
+
+from universal_transfer_operator.datasets.base import Dataset
+
+
+@attr.define
+class TransferParameters:
+    if_exists: str = "replace"
+
+
+def check_if_connection_exists(conn_id: str) -> bool:
+    """
+    Given an Airflow connection ID, identify if it exists.
+    Return True if it does, or False if Airflow raises an AirflowNotFoundException for the connection.
+
+    :param conn_id: Airflow connection ID
+    :return bool: If the connection exists, return True
+    """
+    try:
+        BaseHook.get_connection(conn_id)
+    except AirflowNotFoundException:
+        return False
+    return True
+
+
+def get_dataset_connection_type(dataset: Dataset) -> str:
+    """
+    Given a dataset, fetch the connection type based on the Airflow connection.
+    """
+    return BaseHook.get_connection(dataset.conn_id).conn_type
+
+
+def get_class_name(module_ref: Any, suffix: str = "Location") -> str:
+    """Get the class name to be dynamically imported. Class names are expected to be in one of the following formats,
+    for example -
+    module name: test
+    suffix: Abc
+
+    expected class names -
+    1. TESTAbc
+    2. TestAbc
+
+    :param module_ref: Module from which to get the class location type implementation
+    :param suffix: suffix for the class name
+    """
+    module_name = module_ref.__name__.split(".")[-1]
+    class_names_formats = [
+        f"{module_name.title()}{suffix}",
+        f"{module_name.upper()}{suffix}",
+    ]
+    for class_names_format in class_names_formats:
+        if hasattr(module_ref, class_names_format):
+            return class_names_format
+
+    raise ValueError(
+        "No expected class name found, please note that the class names should follow one of the expected formats."
+    )
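A quick sketch of how `get_class_name()` is used by `get_transfer_integration()` in `integrations/__init__.py` to resolve the integration class for a module. This follows directly from the naming convention above (module name "fivetran" + suffix "Integration"):

```python
import universal_transfer_operator.integrations.fivetran as fivetran_module

from universal_transfer_operator.utils import get_class_name

# "fivetran".title() + "Integration" -> "FivetranIntegration", which exists on the module.
class_name = get_class_name(module_ref=fivetran_module, suffix="Integration")
assert class_name == "FivetranIntegration"
```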