This repository has been archived by the owner on Apr 11, 2024. It is now read-only.

Commit

Add sftp dataprovider (#1802)
# Description
## What is the current behavior?
Currently, we don't have a data provider for SFTP.


closes: #1725 

## What is the new behavior?
1. Added SFTP data provider (see the usage sketch below)
2. Added test cases for the read() and write() methods
3. Added an SFTP server to CI
4. Added logic to add Airflow connections to the pytest session
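
For context, a minimal usage sketch (not part of this diff) based on the tests added below; the remote path and the `sftp_conn` connection id are placeholders:

```python
from universal_transfer_operator.data_providers import create_dataprovider
from universal_transfer_operator.datasets.file.base import File

# create_dataprovider resolves SFTPDataProvider from the "sftp" scheme of the dataset path.
provider = create_dataprovider(dataset=File(path="sftp://upload/sample.csv", conn_id="sftp_conn"))

for file_stream in provider.read():  # read() yields FileStream objects
    print(file_stream.remote_obj_buffer.read())
```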


## Does this introduce a breaking change?
Nope

### Checklist
- [ ] Created tests which fail without the change (if possible)
- [ ] Extended the README / documentation, if necessary
utkarsharma2 committed Mar 2, 2023
1 parent 76b92b8 commit 8554233
Showing 10 changed files with 432 additions and 4 deletions.
5 changes: 5 additions & 0 deletions pyproject.toml
@@ -86,13 +86,18 @@ all = [
"protobuf<=3.20", # Google bigquery client require protobuf <= 3.20.0. We can remove the limitation when this limitation is removed
"openlineage-airflow>=0.17.0",
"airflow-provider-fivetran>=1.1.3",
"apache-airflow-providers-sftp"
]
doc = [
"myst-parser>=0.17",
"sphinx>=4.4.0",
"sphinx-autoapi>=2.0.0",
"sphinx-rtd-theme"
]
sftp = [
"apache-airflow-providers-sftp",
"smart-open[ssh]>=5.2.1",
]

[project.urls]
Home = "https://astronomer.io/"
1 change: 1 addition & 0 deletions src/universal_transfer_operator/data_providers/__init__.py
@@ -12,6 +12,7 @@
"aws": "universal_transfer_operator.data_providers.filesystem.aws.s3",
"gs": "universal_transfer_operator.data_providers.filesystem.google.cloud.gcs",
"google_cloud_platform": "universal_transfer_operator.data_providers.filesystem.google.cloud.gcs",
"sftp": "universal_transfer_operator.data_providers.filesystem.sftp",
}


src/universal_transfer_operator/data_providers/filesystem/base.py
@@ -72,7 +72,7 @@ def transport_params(self) -> dict | None: # skipcq: PYL-R0201

    def check_if_exists(self) -> bool:
        """Return true if the dataset exists"""
        raise NotImplementedError
        return False

    def check_if_transfer_supported(self, source_dataset: Dataset) -> bool:
        """
@@ -107,8 +107,11 @@ def _convert_remote_file_to_byte_stream(self, file: str) -> io.IOBase:
        remote_obj_buffer.seek(0)
        return remote_obj_buffer

    def write(self, source_ref):
        """Write the data from local reference location to the dataset"""
    def write(self, source_ref: FileStream):
        """
        Write the data from local reference location to the dataset
        :param source_ref: Source FileStream object which will be used to read data
        """
        return self.write_using_smart_open(source_ref=source_ref)

    def write_using_smart_open(self, source_ref: FileStream):
134 changes: 134 additions & 0 deletions src/universal_transfer_operator/data_providers/filesystem/sftp.py
@@ -0,0 +1,134 @@
from __future__ import annotations

from functools import cached_property
from urllib.parse import ParseResult, urlparse, urlunparse

import attr
import smart_open
from airflow.providers.sftp.hooks.sftp import SFTPHook

from universal_transfer_operator.constants import Location, TransferMode
from universal_transfer_operator.data_providers.filesystem.base import (
    BaseFilesystemProviders,
    FileStream,
)
from universal_transfer_operator.datasets.file.base import File
from universal_transfer_operator.utils import TransferParameters


class SFTPDataProvider(BaseFilesystemProviders):
"""
DataProviders interactions with GS Dataset.
"""

def __init__(
self,
dataset: File,
transfer_params: TransferParameters = attr.field(
factory=TransferParameters,
converter=lambda val: TransferParameters(**val) if isinstance(val, dict) else val,
),
transfer_mode: TransferMode = TransferMode.NONNATIVE,
):
super().__init__(
dataset=dataset,
transfer_params=transfer_params,
transfer_mode=transfer_mode,
)
self.transfer_mapping = {
Location.S3,
Location.GS,
}

    @cached_property
    def hook(self) -> SFTPHook:
        """Return an instance of the SFTPHook Airflow hook."""
        return SFTPHook(ssh_conn_id=self.dataset.conn_id)

    @property
    def paths(self) -> list[str]:
        """Resolve SFTP file paths, using the netloc and path of self.dataset.path as the prefix;
        a path is included only if it starts with that prefix.

        Example - given the remote paths
            - sftp://upload/test.csv
            - sftp://upload/test.json
            - sftp://upload/home.parquet
            - sftp://upload/sample.ndjson
        a dataset path of "sftp://upload/test" resolves to sftp://upload/test.csv and sftp://upload/test.json.
        """
        url = urlparse(self.dataset.path)
        uri = self.get_uri()
        full_paths = []
        prefixes = self.hook.get_tree_map(url.netloc, prefix=url.netloc + url.path)
        for keys in prefixes:
            if len(keys) > 0:
                full_paths.extend(keys)
        paths = [uri + "/" + path for path in full_paths]
        return paths

    @property
    def transport_params(self) -> dict:
        """Get SFTP credentials for the storage location."""
        client = self.hook.get_connection(self.dataset.conn_id)
        extra_options = client.extra_dejson
        if "key_file" in extra_options:
            key_file = extra_options.get("key_file")
            return {"connect_kwargs": {"key_filename": key_file}}
        elif client.password:
            return {"connect_kwargs": {"password": client.password}}
        raise ValueError("SFTP credentials are not set in the connection.")

    def get_uri(self):
        client = self.hook.get_connection(self.dataset.conn_id)
        return client.get_uri()

    @staticmethod
    def _get_url_path(dst_url: ParseResult, src_url: ParseResult) -> str:
        """
        Get the file path, giving priority to the destination file path.
        :return: URL path
        """
        path = dst_url.path if dst_url.path else src_url.path
        return dst_url.hostname + path

    def get_complete_url(self, dst_url: str, src_url: str) -> str:
        """
        Get the complete URL, filling in host, port, username and password when they are not provided in `dst_url`.
        """
        complete_url = urlparse(self.get_uri())
        _dst_url = urlparse(dst_url)
        _src_url = urlparse(src_url)

        path = self._get_url_path(dst_url=_dst_url, src_url=_src_url)

        final_url = complete_url._replace(path=path)

        return urlunparse(final_url)

    def write_using_smart_open(self, source_ref: FileStream) -> str:
        """Write the source data from the remote object's I/O buffer to the dataset using smart_open.
        :param source_ref: FileStream object of the source dataset
        :return: File path used for the write
        """
        mode = "wb" if self.read_as_binary(source_ref.actual_filename) else "w"
        complete_url = self.get_complete_url(self.dataset.path, source_ref.actual_filename)
        with smart_open.open(complete_url, mode=mode, transport_params=self.transport_params) as stream:
            stream.write(source_ref.remote_obj_buffer.read())
        return complete_url

    @property
    def openlineage_dataset_namespace(self) -> str:
        """
        Returns the OpenLineage dataset namespace as per
        https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md
        """
        raise NotImplementedError

    @property
    def openlineage_dataset_name(self) -> str:
        """
        Returns the OpenLineage dataset name as per
        https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md
        """
        raise NotImplementedError
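
A quick reference sketch (not part of this diff) of the two credential styles `transport_params` above handles: a `key_file` in the connection extras takes priority, otherwise the connection password is used. The host, port and login values below are placeholders.

```python
from airflow.models import Connection

# Hypothetical key-based SFTP connection: transport_params reads "key_file" from the extras.
key_based_conn = Connection(
    conn_id="sftp_conn",
    conn_type="sftp",
    host="localhost",
    login="foo",
    port=2222,
    extra='{"key_file": "/home/foo/.ssh/id_rsa"}',
)

# Hypothetical password-based SFTP connection: used when no "key_file" is present.
password_based_conn = Connection(
    conn_id="sftp_conn",
    conn_type="sftp",
    host="localhost",
    login="foo",
    password="changeme",
    port=2222,
)
```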
36 changes: 35 additions & 1 deletion tests/conftest.py
@@ -1,13 +1,47 @@
import logging
import os

import pytest
from airflow.models import Connection, DagRun, TaskInstance as TI
import yaml
from airflow.models import DAG, Connection, DagRun, TaskInstance as TI
from airflow.utils import timezone
from airflow.utils.db import create_default_connections
from airflow.utils.session import create_session
from utils.test_utils import create_unique_str

DEFAULT_DATE = timezone.datetime(2016, 1, 1)
UNIQUE_HASH_SIZE = 16


@pytest.fixture
def sample_dag():
    dag_id = create_unique_str(UNIQUE_HASH_SIZE)
    yield DAG(dag_id, start_date=DEFAULT_DATE)
    with create_session() as session_:
        session_.query(DagRun).delete()
        session_.query(TI).delete()


@pytest.fixture(scope="session", autouse=True)
def create_database_connections():
    with open(os.path.dirname(__file__) + "/../test-connections.yaml") as fp:
        yaml_with_env = os.path.expandvars(fp.read())
        yaml_dicts = yaml.safe_load(yaml_with_env)
        connections = []
        for i in yaml_dicts["connections"]:
            connections.append(Connection(**i))
    with create_session() as session:
        session.query(DagRun).delete()
        session.query(TI).delete()
        session.query(Connection).delete()
        create_default_connections(session)
        for conn in connections:
            last_conn = session.query(Connection).filter(Connection.conn_id == conn.conn_id).first()
            if last_conn is not None:
                session.delete(last_conn)
                session.flush()
                logging.info(
                    "Overriding existing conn_id %s with connection specified in test-connections.yaml",
                    conn.conn_id,
                )
            session.add(conn)
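
For orientation, a hedged sketch of the shape this fixture expects from `test-connections.yaml`: environment variables are expanded first, and every entry under `connections` is unpacked straight into `airflow.models.Connection`, so each key must be a valid `Connection` kwarg. The `sftp_conn` entry below is illustrative only (the tests in this change expect a connection with that id to exist).

```python
import os

import yaml
from airflow.models import Connection

# Hypothetical excerpt of test-connections.yaml; values are placeholders.
SAMPLE_YAML = """
connections:
  - conn_id: sftp_conn
    conn_type: sftp
    host: localhost
    login: foo
    password: ${SFTP_PASSWORD}
    port: 2222
"""

# Mirrors the fixture: expand env vars, parse the YAML, build Connection objects.
entries = yaml.safe_load(os.path.expandvars(SAMPLE_YAML))["connections"]
connections = [Connection(**entry) for entry in entries]
```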
4 changes: 4 additions & 0 deletions tests/data/sample.csv
@@ -0,0 +1,4 @@
id,name
1,First
2,Second
3,Third with unicode पांचाल
2 changes: 2 additions & 0 deletions tests/test_data_provider/test_data_provider.py
@@ -3,6 +3,7 @@
from universal_transfer_operator.data_providers import create_dataprovider
from universal_transfer_operator.data_providers.filesystem.aws.s3 import S3DataProvider
from universal_transfer_operator.data_providers.filesystem.google.cloud.gcs import GCSDataProvider
from universal_transfer_operator.data_providers.filesystem.sftp import SFTPDataProvider
from universal_transfer_operator.datasets.file.base import File


@@ -11,6 +12,7 @@
    [
        {"dataset": File("s3://astro-sdk-test/uto/", conn_id="aws_default"), "expected": S3DataProvider},
        {"dataset": File("gs://uto-test/uto/", conn_id="google_cloud_default"), "expected": GCSDataProvider},
        {"dataset": File("sftp://upload/sample.csv", conn_id="sftp_default"), "expected": SFTPDataProvider},
    ],
    ids=lambda d: d["dataset"].conn_id,
)
64 changes: 64 additions & 0 deletions tests/test_data_provider/test_filesystem/test_sftp.py
@@ -0,0 +1,64 @@
import pathlib

import pandas as pd
from airflow.providers.sftp.hooks.sftp import SFTPHook
from utils.test_utils import create_unique_str

from universal_transfer_operator.data_providers import create_dataprovider
from universal_transfer_operator.data_providers.filesystem.base import FileStream
from universal_transfer_operator.datasets.file.base import File

CWD = pathlib.Path(__file__).parent
DATA_DIR = str(CWD) + "/../../data/"


def upload_file_to_sftp_server(conn_id: str, local_path: str, remote_path: str):
    sftp = SFTPHook(ssh_conn_id=conn_id)
    sftp.store_file(remote_full_path=remote_path, local_full_path=local_path)


def test_sftp_read():
    """
    Test to validate working of SFTPDataProvider.read() method
    """
    filepath = DATA_DIR + "sample.csv"
    remote_path = f"/upload/{create_unique_str(10)}.csv"
    upload_file_to_sftp_server(conn_id="sftp_conn", local_path=filepath, remote_path=remote_path)

    # remote_path already starts with "/", so "sftp:/" + remote_path yields "sftp://upload/...".
    dataprovider = create_dataprovider(dataset=File(path=f"sftp:/{remote_path}", conn_id="sftp_conn"))
    iterator_obj = dataprovider.read()
    source_data = next(iterator_obj)

    sftp_df = pd.read_csv(source_data.remote_obj_buffer)
    true_df = pd.read_csv(filepath)
    assert sftp_df.equals(true_df)


def download_file_from_sftp(conn_id: str, local_path: str, remote_path: str):
    sftp = SFTPHook(ssh_conn_id=conn_id)
    sftp.retrieve_file(
        local_full_path=local_path,
        remote_full_path=remote_path,
    )


def test_sftp_write():
    """
    Test to validate working of SFTPDataProvider.write() method
    """
    local_filepath = DATA_DIR + "sample.csv"
    file_name = f"{create_unique_str(10)}.csv"
    remote_filepath = f"sftp://upload/{file_name}"

    dataprovider = create_dataprovider(dataset=File(path=remote_filepath, conn_id="sftp_conn"))
    with open(local_filepath) as local_file:
        fs = FileStream(remote_obj_buffer=local_file, actual_filename=local_filepath)
        dataprovider.write(source_ref=fs)

    downloaded_file = f"/tmp/{file_name}"
    download_file_from_sftp(
        conn_id="sftp_conn", local_path=downloaded_file, remote_path=f"{remote_filepath.split('sftp:/')[1]}"
    )

    sftp_df = pd.read_csv(downloaded_file)
    true_df = pd.read_csv(local_filepath)
    assert sftp_df.equals(true_df)
