From bac7b3027d57d2a31acb9a2d078c6af4dc777162 Mon Sep 17 00:00:00 2001 From: Jarek Potiuk Date: Fri, 20 Jan 2023 20:32:11 +0100 Subject: [PATCH] Fix SFTP operator's template fields processing (#29068) The SFTP operator had logic in `__init__` that was running checks on passed arguments, but that means that the operator woudl fail when using TaskFlow. Fixes: #27328 --- airflow/providers/sftp/operators/sftp.py | 42 +++++++++++---------- tests/providers/sftp/operators/test_sftp.py | 35 +++++++++++++---- 2 files changed, 50 insertions(+), 27 deletions(-) diff --git a/airflow/providers/sftp/operators/sftp.py b/airflow/providers/sftp/operators/sftp.py index 8884818ffb493..2bb2cee9f8271 100644 --- a/airflow/providers/sftp/operators/sftp.py +++ b/airflow/providers/sftp/operators/sftp.py @@ -100,23 +100,26 @@ def __init__( self.operation = operation self.confirm = confirm self.create_intermediate_dirs = create_intermediate_dirs + self.local_filepath = local_filepath + self.remote_filepath = remote_filepath - self.local_filepath_was_str = False - if isinstance(local_filepath, str): - self.local_filepath = [local_filepath] - self.local_filepath_was_str = True + def execute(self, context: Any) -> str | list[str] | None: + local_filepath_was_str = False + if isinstance(self.local_filepath, str): + local_filepath_array = [self.local_filepath] + local_filepath_was_str = True else: - self.local_filepath = local_filepath + local_filepath_array = self.local_filepath - if isinstance(remote_filepath, str): - self.remote_filepath = [remote_filepath] + if isinstance(self.remote_filepath, str): + remote_filepath_array = [self.remote_filepath] else: - self.remote_filepath = remote_filepath + remote_filepath_array = self.remote_filepath - if len(self.local_filepath) != len(self.remote_filepath): + if len(local_filepath_array) != len(remote_filepath_array): raise ValueError( - f"{len(self.local_filepath)} paths in local_filepath " - f"!= {len(self.remote_filepath)} paths in remote_filepath" + f"{len(local_filepath_array)} paths in local_filepath " + f"!= {len(remote_filepath_array)} paths in remote_filepath" ) if not (self.operation.lower() == SFTPOperation.GET or self.operation.lower() == SFTPOperation.PUT): @@ -145,7 +148,6 @@ def __init__( ) self.sftp_hook = SFTPHook(ssh_hook=self.ssh_hook) - def execute(self, context: Any) -> str | list[str] | None: file_msg = None try: if self.ssh_conn_id: @@ -168,23 +170,23 @@ def execute(self, context: Any) -> str | list[str] | None: ) self.sftp_hook.remote_host = self.remote_host - for local_filepath, remote_filepath in zip(self.local_filepath, self.remote_filepath): + for _local_filepath, _remote_filepath in zip(local_filepath_array, remote_filepath_array): if self.operation.lower() == SFTPOperation.GET: - local_folder = os.path.dirname(local_filepath) + local_folder = os.path.dirname(_local_filepath) if self.create_intermediate_dirs: Path(local_folder).mkdir(parents=True, exist_ok=True) - file_msg = f"from {remote_filepath} to {local_filepath}" + file_msg = f"from {_remote_filepath} to {_local_filepath}" self.log.info("Starting to transfer %s", file_msg) - self.sftp_hook.retrieve_file(remote_filepath, local_filepath) + self.sftp_hook.retrieve_file(_remote_filepath, _local_filepath) else: - remote_folder = os.path.dirname(remote_filepath) + remote_folder = os.path.dirname(_remote_filepath) if self.create_intermediate_dirs: self.sftp_hook.create_directory(remote_folder) - file_msg = f"from {local_filepath} to {remote_filepath}" + file_msg = f"from {_local_filepath} to {_remote_filepath}" self.log.info("Starting to transfer file %s", file_msg) - self.sftp_hook.store_file(remote_filepath, local_filepath, confirm=self.confirm) + self.sftp_hook.store_file(_remote_filepath, _local_filepath, confirm=self.confirm) except Exception as e: raise AirflowException(f"Error while transferring {file_msg}, error: {str(e)}") - return self.local_filepath[0] if self.local_filepath_was_str else self.local_filepath + return local_filepath_array[0] if local_filepath_was_str else local_filepath_array diff --git a/tests/providers/sftp/operators/test_sftp.py b/tests/providers/sftp/operators/test_sftp.py index 73d66d393ccfd..92137e2a5c183 100644 --- a/tests/providers/sftp/operators/test_sftp.py +++ b/tests/providers/sftp/operators/test_sftp.py @@ -396,16 +396,22 @@ def test_unequal_local_remote_file_paths(self): task_id="test_sftp_unequal_paths", local_filepath="/tmp/test", remote_filepath=["/tmp/test1", "/tmp/test2"], - ) + ).execute(None) - def test_str_filepaths_converted_to_lists(self): + @mock.patch("airflow.providers.sftp.operators.sftp.SFTPHook.retrieve_file") + def test_str_filepaths_get(self, mock_get): local_filepath = "/tmp/test" remote_filepath = "/tmp/remotetest" - sftp_op = SFTPOperator( - task_id="test_str_to_list", local_filepath=local_filepath, remote_filepath=remote_filepath - ) - assert sftp_op.local_filepath == [local_filepath] - assert sftp_op.remote_filepath == [remote_filepath] + SFTPOperator( + task_id="test_str_to_list", + sftp_hook=self.sftp_hook, + local_filepath=local_filepath, + remote_filepath=remote_filepath, + operation=SFTPOperation.GET, + ).execute(None) + assert mock_get.call_count == 1 + args, _ = mock_get.call_args_list[0] + assert args == (remote_filepath, local_filepath) @mock.patch("airflow.providers.sftp.operators.sftp.SFTPHook.retrieve_file") def test_multiple_paths_get(self, mock_get): @@ -425,6 +431,21 @@ def test_multiple_paths_get(self, mock_get): assert args0 == (remote_filepath[0], local_filepath[0]) assert args1 == (remote_filepath[1], local_filepath[1]) + @mock.patch("airflow.providers.sftp.operators.sftp.SFTPHook.store_file") + def test_str_filepaths_put(self, mock_get): + local_filepath = "/tmp/test" + remote_filepath = "/tmp/remotetest" + SFTPOperator( + task_id="test_str_to_list", + sftp_hook=self.sftp_hook, + local_filepath=local_filepath, + remote_filepath=remote_filepath, + operation=SFTPOperation.PUT, + ).execute(None) + assert mock_get.call_count == 1 + args, _ = mock_get.call_args_list[0] + assert args == (remote_filepath, local_filepath) + @mock.patch("airflow.providers.sftp.operators.sftp.SFTPHook.store_file") def test_multiple_paths_put(self, mock_put): local_filepath = ["/tmp/ltest1", "/tmp/ltest2"]