From 533afb5128383958889bc653226f46947c642351 Mon Sep 17 00:00:00 2001 From: Dagang Wei Date: Sat, 19 Aug 2023 08:35:51 -0700 Subject: [PATCH] Add parameter sftp_prefetch to SFTPToGCSOperator (#33274) --- .../providers/google/cloud/transfers/sftp_to_gcs.py | 5 ++++- airflow/providers/sftp/hooks/sftp.py | 5 +++-- airflow/providers/sftp/provider.yaml | 1 + generated/provider_dependencies.json | 3 ++- .../google/cloud/transfers/test_sftp_to_gcs.py | 12 +++++++----- 5 files changed, 17 insertions(+), 9 deletions(-) diff --git a/airflow/providers/google/cloud/transfers/sftp_to_gcs.py b/airflow/providers/google/cloud/transfers/sftp_to_gcs.py index 394ea3e8a2003..eee8bd14b5a09 100644 --- a/airflow/providers/google/cloud/transfers/sftp_to_gcs.py +++ b/airflow/providers/google/cloud/transfers/sftp_to_gcs.py @@ -69,6 +69,7 @@ class SFTPToGCSOperator(BaseOperator): If set as a sequence, the identities from the list must grant Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). + :param sftp_prefetch: Whether to enable SFTP prefetch, the default is True. """ template_fields: Sequence[str] = ( @@ -90,6 +91,7 @@ def __init__( gzip: bool = False, move_object: bool = False, impersonation_chain: str | Sequence[str] | None = None, + sftp_prefetch: bool = True, **kwargs, ) -> None: super().__init__(**kwargs) @@ -103,6 +105,7 @@ def __init__( self.sftp_conn_id = sftp_conn_id self.move_object = move_object self.impersonation_chain = impersonation_chain + self.sftp_prefetch = sftp_prefetch def execute(self, context: Context): gcs_hook = GCSHook( @@ -151,7 +154,7 @@ def _copy_single_object( ) with NamedTemporaryFile("w") as tmp: - sftp_hook.retrieve_file(source_path, tmp.name) + sftp_hook.retrieve_file(source_path, tmp.name, prefetch=self.sftp_prefetch) gcs_hook.upload( bucket_name=self.destination_bucket, diff --git a/airflow/providers/sftp/hooks/sftp.py b/airflow/providers/sftp/hooks/sftp.py index f1c3fbc0cc10e..e02872b4b03b5 100644 --- a/airflow/providers/sftp/hooks/sftp.py +++ b/airflow/providers/sftp/hooks/sftp.py @@ -223,7 +223,7 @@ def delete_directory(self, path: str) -> None: conn = self.get_conn() conn.rmdir(path) - def retrieve_file(self, remote_full_path: str, local_full_path: str) -> None: + def retrieve_file(self, remote_full_path: str, local_full_path: str, prefetch: bool = True) -> None: """Transfer the remote file to a local location. If local_full_path is a string path, the file will be put @@ -231,9 +231,10 @@ def retrieve_file(self, remote_full_path: str, local_full_path: str) -> None: :param remote_full_path: full path to the remote file :param local_full_path: full path to the local file + :param prefetch: controls whether prefetch is performed (default: True) """ conn = self.get_conn() - conn.get(remote_full_path, local_full_path) + conn.get(remote_full_path, local_full_path, prefetch=prefetch) def store_file(self, remote_full_path: str, local_full_path: str, confirm: bool = True) -> None: """Transfer a local file to the remote location. diff --git a/airflow/providers/sftp/provider.yaml b/airflow/providers/sftp/provider.yaml index ec4df7ed98c04..8dfdaea911d66 100644 --- a/airflow/providers/sftp/provider.yaml +++ b/airflow/providers/sftp/provider.yaml @@ -54,6 +54,7 @@ versions: dependencies: - apache-airflow>=2.4.0 - apache-airflow-providers-ssh>=2.1.0 + - paramiko>=2.8.0 integrations: - integration-name: SSH File Transfer Protocol (SFTP) diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index b96c6baf1cc38..8b8a435062864 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -795,7 +795,8 @@ "sftp": { "deps": [ "apache-airflow-providers-ssh>=2.1.0", - "apache-airflow>=2.4.0" + "apache-airflow>=2.4.0", + "paramiko>=2.8.0" ], "cross-providers-deps": [ "openlineage", diff --git a/tests/providers/google/cloud/transfers/test_sftp_to_gcs.py b/tests/providers/google/cloud/transfers/test_sftp_to_gcs.py index 433961eb6e904..b6790c70f9da6 100644 --- a/tests/providers/google/cloud/transfers/test_sftp_to_gcs.py +++ b/tests/providers/google/cloud/transfers/test_sftp_to_gcs.py @@ -73,7 +73,7 @@ def test_execute_copy_single_file(self, sftp_hook, gcs_hook): sftp_hook.assert_called_once_with(SFTP_CONN_ID) sftp_hook.return_value.retrieve_file.assert_called_once_with( - os.path.join(SOURCE_OBJECT_NO_WILDCARD), mock.ANY + os.path.join(SOURCE_OBJECT_NO_WILDCARD), mock.ANY, prefetch=True ) gcs_hook.return_value.upload.assert_called_once_with( @@ -99,6 +99,7 @@ def test_execute_copy_single_file_with_compression(self, sftp_hook, gcs_hook): sftp_conn_id=SFTP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, gzip=True, + sftp_prefetch=False, ) task.execute(None) gcs_hook.assert_called_once_with( @@ -108,7 +109,7 @@ def test_execute_copy_single_file_with_compression(self, sftp_hook, gcs_hook): sftp_hook.assert_called_once_with(SFTP_CONN_ID) sftp_hook.return_value.retrieve_file.assert_called_once_with( - os.path.join(SOURCE_OBJECT_NO_WILDCARD), mock.ANY + os.path.join(SOURCE_OBJECT_NO_WILDCARD), mock.ANY, prefetch=False ) gcs_hook.return_value.upload.assert_called_once_with( @@ -133,6 +134,7 @@ def test_execute_move_single_file(self, sftp_hook, gcs_hook): gcp_conn_id=GCP_CONN_ID, sftp_conn_id=SFTP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, + sftp_prefetch=True, ) task.execute(None) gcs_hook.assert_called_once_with( @@ -142,7 +144,7 @@ def test_execute_move_single_file(self, sftp_hook, gcs_hook): sftp_hook.assert_called_once_with(SFTP_CONN_ID) sftp_hook.return_value.retrieve_file.assert_called_once_with( - os.path.join(SOURCE_OBJECT_NO_WILDCARD), mock.ANY + os.path.join(SOURCE_OBJECT_NO_WILDCARD), mock.ANY, prefetch=True ) gcs_hook.return_value.upload.assert_called_once_with( @@ -181,8 +183,8 @@ def test_execute_copy_with_wildcard(self, sftp_hook, gcs_hook): sftp_hook.return_value.retrieve_file.assert_has_calls( [ - mock.call("main_dir/test_object3.json", mock.ANY), - mock.call("main_dir/sub_dir/test_object3.json", mock.ANY), + mock.call("main_dir/test_object3.json", mock.ANY, prefetch=True), + mock.call("main_dir/sub_dir/test_object3.json", mock.ANY, prefetch=True), ] )