This repository has been archived by the owner on Apr 11, 2024. It is now read-only.
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
# Description ## What is the current behavior? Currently, we don't have a data provider for SFTP. closes: #1725 ## What is the new behavior? 1. Added an SFTP data provider 2. Added test cases for the read() and write() methods 3. Added an SFTP server for CI 4. Added logic to register the Airflow connection during the pytest session. ## Does this introduce a breaking change? No ### Checklist - [ ] Created tests which fail without the change (if possible) - [ ] Extended the README / documentation, if necessary
- Loading branch information
1 parent
76b92b8
commit 8554233
Showing
10 changed files
with
432 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
from __future__ import annotations | ||
|
||
from functools import cached_property | ||
from urllib.parse import ParseResult, urlparse, urlunparse | ||
|
||
import attr | ||
import smart_open | ||
from airflow.providers.sftp.hooks.sftp import SFTPHook | ||
|
||
from universal_transfer_operator.constants import Location, TransferMode | ||
from universal_transfer_operator.data_providers.filesystem.base import ( | ||
BaseFilesystemProviders, | ||
FileStream, | ||
) | ||
from universal_transfer_operator.datasets.file.base import File | ||
from universal_transfer_operator.utils import TransferParameters | ||
|
||
|
||
class SFTPDataProvider(BaseFilesystemProviders):
    """
    DataProvider for interacting with SFTP-hosted file datasets.
    """

    def __init__(
        self,
        dataset: File,
        transfer_params: TransferParameters = attr.field(
            factory=TransferParameters,
            converter=lambda val: TransferParameters(**val) if isinstance(val, dict) else val,
        ),
        transfer_mode: TransferMode = TransferMode.NONNATIVE,
    ):
        super().__init__(
            dataset=dataset,
            transfer_params=transfer_params,
            transfer_mode=transfer_mode,
        )
        # Locations this provider supports transfers with.
        self.transfer_mapping = {
            Location.S3,
            Location.GS,
        }

    @cached_property
    def hook(self) -> SFTPHook:
        """Return an instance of the SFTPHook Airflow hook."""
        return SFTPHook(ssh_conn_id=self.dataset.conn_id)

    @property
    def paths(self) -> list[str]:
        """Resolve SFTP file paths with netloc of self.dataset.path as prefix. Paths are added if they start with prefix
        Example - if there are multiple paths like
            - sftp://upload/test.csv
            - sftp://upload/test.json
            - sftp://upload/home.parquet
            - sftp://upload/sample.ndjson

        if self.dataset.path is "sftp://upload/test" will return sftp://upload/test.csv and sftp://upload/test.json
        """
        url = urlparse(self.dataset.path)
        uri = self.get_uri()
        full_paths = []
        # get_tree_map returns several lists of remote paths; keep non-empty ones.
        prefixes = self.hook.get_tree_map(url.netloc, prefix=url.netloc + url.path)
        for keys in prefixes:
            if keys:
                full_paths.extend(keys)
        paths = [uri + "/" + path for path in full_paths]
        return paths

    @property
    def transport_params(self) -> dict:
        """Get SFTP credentials for smart_open to use.

        Prefers a private key file configured in the connection extras; falls back
        to password auth.

        :raises ValueError: if neither a key file nor a password is configured.
        """
        client = self.hook.get_connection(self.dataset.conn_id)
        extra_options = client.extra_dejson
        if "key_file" in extra_options:
            key_file = extra_options.get("key_file")
            return {"connect_kwargs": {"key_filename": key_file}}
        elif client.password:
            return {"connect_kwargs": {"password": client.password}}
        raise ValueError("SFTP credentials are not set in the connection.")

    def get_uri(self):
        """Return the URI of the Airflow connection backing this dataset."""
        client = self.hook.get_connection(self.dataset.conn_id)
        return client.get_uri()

    @staticmethod
    def _get_url_path(dst_url: ParseResult, src_url: ParseResult) -> str:
        """
        Get correct file path, priority is given to destination file path.

        :return: URL path
        """
        path = dst_url.path if dst_url.path else src_url.path
        return dst_url.hostname + path

    def get_complete_url(self, dst_url: str, src_url: str) -> str:
        """
        Get complete url with host, port, username, password if they are not provided in the `dst_url`
        """
        complete_url = urlparse(self.get_uri())
        _dst_url = urlparse(dst_url)
        _src_url = urlparse(src_url)

        path = self._get_url_path(dst_url=_dst_url, src_url=_src_url)

        final_url = complete_url._replace(path=path)

        return urlunparse(final_url)

    def write_using_smart_open(self, source_ref: FileStream) -> str:
        """Write the source data from remote object i/o buffer to the dataset using smart open

        :param source_ref: FileStream object of source dataset
        :return: File path that is the used for write pattern
        """
        mode = "wb" if self.read_as_binary(source_ref.actual_filename) else "w"
        complete_url = self.get_complete_url(self.dataset.path, source_ref.actual_filename)
        with smart_open.open(complete_url, mode=mode, transport_params=self.transport_params) as stream:
            stream.write(source_ref.remote_obj_buffer.read())
        return complete_url

    @property
    def openlineage_dataset_namespace(self) -> str:
        """
        Returns the open lineage dataset namespace as per
        https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md
        """
        raise NotImplementedError

    @property
    def openlineage_dataset_name(self) -> str:
        """
        Returns the open lineage dataset name as per
        https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md
        """
        raise NotImplementedError
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,13 +1,47 @@ | ||
import logging | ||
import os | ||
|
||
import pytest | ||
from airflow.models import Connection, DagRun, TaskInstance as TI | ||
import yaml | ||
from airflow.models import DAG, Connection, DagRun, TaskInstance as TI | ||
from airflow.utils import timezone | ||
from airflow.utils.db import create_default_connections | ||
from airflow.utils.session import create_session | ||
from utils.test_utils import create_unique_str | ||
|
||
DEFAULT_DATE = timezone.datetime(2016, 1, 1) | ||
UNIQUE_HASH_SIZE = 16 | ||
|
||
|
||
@pytest.fixture
def sample_dag():
    """Yield a DAG with a unique id; purge DagRun and TaskInstance rows afterwards."""
    unique_dag_id = create_unique_str(UNIQUE_HASH_SIZE)
    yield DAG(unique_dag_id, start_date=DEFAULT_DATE)
    with create_session() as cleanup_session:
        for model in (DagRun, TI):
            cleanup_session.query(model).delete()
|
||
|
||
@pytest.fixture(scope="session", autouse=True)
def create_database_connections():
    """Populate the Airflow metadata DB with connections from test-connections.yaml.

    Runs once per test session: clears DagRun/TaskInstance/Connection tables,
    recreates Airflow's default connections, then adds every connection listed
    in the YAML file, replacing any pre-existing connection with the same
    conn_id. Environment variables in the file are expanded before parsing.
    """
    with open(os.path.join(os.path.dirname(__file__), "..", "test-connections.yaml")) as fp:
        yaml_with_env = os.path.expandvars(fp.read())
        yaml_dicts = yaml.safe_load(yaml_with_env)
        connections = []
        for i in yaml_dicts["connections"]:
            connections.append(Connection(**i))
    with create_session() as session:
        session.query(DagRun).delete()
        session.query(TI).delete()
        session.query(Connection).delete()
        create_default_connections(session)
        for conn in connections:
            last_conn = session.query(Connection).filter(Connection.conn_id == conn.conn_id).first()
            if last_conn is not None:
                session.delete(last_conn)
                session.flush()
                # Filename corrected: the file read above is test-connections.yaml.
                logging.info(
                    "Overriding existing conn_id %s with connection specified in test-connections.yaml",
                    conn.conn_id,
                )
            session.add(conn)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
id,name | ||
1,First | ||
2,Second | ||
3,Third with unicode पांचाल |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
import pathlib | ||
|
||
import pandas as pd | ||
from airflow.providers.sftp.hooks.sftp import SFTPHook | ||
from utils.test_utils import create_unique_str | ||
|
||
from universal_transfer_operator.data_providers import create_dataprovider | ||
from universal_transfer_operator.data_providers.filesystem.base import FileStream | ||
from universal_transfer_operator.datasets.file.base import File | ||
|
||
CWD = pathlib.Path(__file__).parent | ||
DATA_DIR = str(CWD) + "/../../data/" | ||
|
||
|
||
def upload_file_to_sftp_server(conn_id: str, local_path: str, remote_path: str):
    """Push a local file to the SFTP server referenced by ``conn_id``."""
    SFTPHook(ssh_conn_id=conn_id).store_file(remote_full_path=remote_path, local_full_path=local_path)
|
||
|
||
def test_sftp_read():
    """
    Test to validate working of SFTPDataProvider.read() method
    """
    local_file = DATA_DIR + "sample.csv"
    remote_path = f"/upload/{create_unique_str(10)}.csv"
    upload_file_to_sftp_server(conn_id="sftp_conn", local_path=local_file, remote_path=remote_path)

    provider = create_dataprovider(dataset=File(path=f"sftp:/{remote_path}", conn_id="sftp_conn"))
    # read() yields FileStream objects; the first one holds our uploaded file.
    source_data = next(provider.read())

    remote_df = pd.read_csv(source_data.remote_obj_buffer)
    expected_df = pd.read_csv(local_file)
    assert remote_df.equals(expected_df)
|
||
|
||
def download_file_from_sftp(conn_id: str, local_path: str, remote_path: str):
    """Fetch a remote file from the SFTP server referenced by ``conn_id``."""
    hook = SFTPHook(ssh_conn_id=conn_id)
    hook.retrieve_file(local_full_path=local_path, remote_full_path=remote_path)
|
||
|
||
def test_sftp_write():
    """
    Test to validate working of SFTPDataProvider.write() method: upload via the
    provider, download via SFTPHook, and compare the two CSVs.
    """
    local_filepath = DATA_DIR + "sample.csv"
    file_name = f"{create_unique_str(10)}.csv"
    remote_filepath = f"sftp://upload/{file_name}"

    dataprovider = create_dataprovider(dataset=File(path=remote_filepath, conn_id="sftp_conn"))
    # Use a context manager so the local file handle is closed after the upload
    # (the previous version leaked it).
    with open(local_filepath) as local_fd:
        fs = FileStream(remote_obj_buffer=local_fd, actual_filename=local_filepath)
        dataprovider.write(source_ref=fs)

    downloaded_file = f"/tmp/{file_name}"
    download_file_from_sftp(
        conn_id="sftp_conn", local_path=downloaded_file, remote_path=f"{remote_filepath.split('sftp:/')[1]}"
    )

    sftp_df = pd.read_csv(downloaded_file)
    true_df = pd.read_csv(local_filepath)
    assert sftp_df.equals(true_df)
Oops, something went wrong.