Skip to content

Commit

Permalink
Add parameter sftp_prefetch to SFTPToGCSOperator (#33274)
Browse files: browse the repository at this point in the history
  • Loading branch information
functicons committed Aug 19, 2023
1 parent de17b93 commit 533afb5
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 9 deletions.
5 changes: 4 additions & 1 deletion airflow/providers/google/cloud/transfers/sftp_to_gcs.py
Expand Up @@ -69,6 +69,7 @@ class SFTPToGCSOperator(BaseOperator):
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
:param sftp_prefetch: Whether to enable SFTP prefetch (default: True).
"""

template_fields: Sequence[str] = (
Expand All @@ -90,6 +91,7 @@ def __init__(
gzip: bool = False,
move_object: bool = False,
impersonation_chain: str | Sequence[str] | None = None,
sftp_prefetch: bool = True,
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand All @@ -103,6 +105,7 @@ def __init__(
self.sftp_conn_id = sftp_conn_id
self.move_object = move_object
self.impersonation_chain = impersonation_chain
self.sftp_prefetch = sftp_prefetch

def execute(self, context: Context):
gcs_hook = GCSHook(
Expand Down Expand Up @@ -151,7 +154,7 @@ def _copy_single_object(
)

with NamedTemporaryFile("w") as tmp:
sftp_hook.retrieve_file(source_path, tmp.name)
sftp_hook.retrieve_file(source_path, tmp.name, prefetch=self.sftp_prefetch)

gcs_hook.upload(
bucket_name=self.destination_bucket,
Expand Down
5 changes: 3 additions & 2 deletions airflow/providers/sftp/hooks/sftp.py
Expand Up @@ -223,17 +223,18 @@ def delete_directory(self, path: str) -> None:
conn = self.get_conn()
conn.rmdir(path)

def retrieve_file(self, remote_full_path: str, local_full_path: str) -> None:
def retrieve_file(self, remote_full_path: str, local_full_path: str, prefetch: bool = True) -> None:
"""Transfer the remote file to a local location.
If local_full_path is a string path, the file will be put
at that location.
:param remote_full_path: full path to the remote file
:param local_full_path: full path to the local file
:param prefetch: controls whether prefetch is performed (default: True)
"""
conn = self.get_conn()
conn.get(remote_full_path, local_full_path)
conn.get(remote_full_path, local_full_path, prefetch=prefetch)

def store_file(self, remote_full_path: str, local_full_path: str, confirm: bool = True) -> None:
"""Transfer a local file to the remote location.
Expand Down
1 change: 1 addition & 0 deletions airflow/providers/sftp/provider.yaml
Expand Up @@ -54,6 +54,7 @@ versions:
dependencies:
- apache-airflow>=2.4.0
- apache-airflow-providers-ssh>=2.1.0
- paramiko>=2.8.0

integrations:
- integration-name: SSH File Transfer Protocol (SFTP)
Expand Down
3 changes: 2 additions & 1 deletion generated/provider_dependencies.json
Expand Up @@ -795,7 +795,8 @@
"sftp": {
"deps": [
"apache-airflow-providers-ssh>=2.1.0",
"apache-airflow>=2.4.0"
"apache-airflow>=2.4.0",
"paramiko>=2.8.0"
],
"cross-providers-deps": [
"openlineage",
Expand Down
12 changes: 7 additions & 5 deletions tests/providers/google/cloud/transfers/test_sftp_to_gcs.py
Expand Up @@ -73,7 +73,7 @@ def test_execute_copy_single_file(self, sftp_hook, gcs_hook):
sftp_hook.assert_called_once_with(SFTP_CONN_ID)

sftp_hook.return_value.retrieve_file.assert_called_once_with(
os.path.join(SOURCE_OBJECT_NO_WILDCARD), mock.ANY
os.path.join(SOURCE_OBJECT_NO_WILDCARD), mock.ANY, prefetch=True
)

gcs_hook.return_value.upload.assert_called_once_with(
Expand All @@ -99,6 +99,7 @@ def test_execute_copy_single_file_with_compression(self, sftp_hook, gcs_hook):
sftp_conn_id=SFTP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
gzip=True,
sftp_prefetch=False,
)
task.execute(None)
gcs_hook.assert_called_once_with(
Expand All @@ -108,7 +109,7 @@ def test_execute_copy_single_file_with_compression(self, sftp_hook, gcs_hook):
sftp_hook.assert_called_once_with(SFTP_CONN_ID)

sftp_hook.return_value.retrieve_file.assert_called_once_with(
os.path.join(SOURCE_OBJECT_NO_WILDCARD), mock.ANY
os.path.join(SOURCE_OBJECT_NO_WILDCARD), mock.ANY, prefetch=False
)

gcs_hook.return_value.upload.assert_called_once_with(
Expand All @@ -133,6 +134,7 @@ def test_execute_move_single_file(self, sftp_hook, gcs_hook):
gcp_conn_id=GCP_CONN_ID,
sftp_conn_id=SFTP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
sftp_prefetch=True,
)
task.execute(None)
gcs_hook.assert_called_once_with(
Expand All @@ -142,7 +144,7 @@ def test_execute_move_single_file(self, sftp_hook, gcs_hook):
sftp_hook.assert_called_once_with(SFTP_CONN_ID)

sftp_hook.return_value.retrieve_file.assert_called_once_with(
os.path.join(SOURCE_OBJECT_NO_WILDCARD), mock.ANY
os.path.join(SOURCE_OBJECT_NO_WILDCARD), mock.ANY, prefetch=True
)

gcs_hook.return_value.upload.assert_called_once_with(
Expand Down Expand Up @@ -181,8 +183,8 @@ def test_execute_copy_with_wildcard(self, sftp_hook, gcs_hook):

sftp_hook.return_value.retrieve_file.assert_has_calls(
[
mock.call("main_dir/test_object3.json", mock.ANY),
mock.call("main_dir/sub_dir/test_object3.json", mock.ANY),
mock.call("main_dir/test_object3.json", mock.ANY, prefetch=True),
mock.call("main_dir/sub_dir/test_object3.json", mock.ANY, prefetch=True),
]
)

Expand Down

0 comments on commit 533afb5

Please sign in to comment.