Skip to content

Commit

Permalink
Fix SFTP operator's template fields processing (#29068)
Browse files Browse the repository at this point in the history
The SFTP operator had logic in `__init__` that was running checks
on passed arguments, but that means that the operator would fail
when using TaskFlow.

Fixes: #27328
  • Loading branch information
potiuk committed Jan 20, 2023
1 parent 10f0f8b commit bac7b30
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 27 deletions.
42 changes: 22 additions & 20 deletions airflow/providers/sftp/operators/sftp.py
Expand Up @@ -100,23 +100,26 @@ def __init__(
self.operation = operation
self.confirm = confirm
self.create_intermediate_dirs = create_intermediate_dirs
self.local_filepath = local_filepath
self.remote_filepath = remote_filepath

self.local_filepath_was_str = False
if isinstance(local_filepath, str):
self.local_filepath = [local_filepath]
self.local_filepath_was_str = True
def execute(self, context: Any) -> str | list[str] | None:
local_filepath_was_str = False
if isinstance(self.local_filepath, str):
local_filepath_array = [self.local_filepath]
local_filepath_was_str = True
else:
self.local_filepath = local_filepath
local_filepath_array = self.local_filepath

if isinstance(remote_filepath, str):
self.remote_filepath = [remote_filepath]
if isinstance(self.remote_filepath, str):
remote_filepath_array = [self.remote_filepath]
else:
self.remote_filepath = remote_filepath
remote_filepath_array = self.remote_filepath

if len(self.local_filepath) != len(self.remote_filepath):
if len(local_filepath_array) != len(remote_filepath_array):
raise ValueError(
f"{len(self.local_filepath)} paths in local_filepath "
f"!= {len(self.remote_filepath)} paths in remote_filepath"
f"{len(local_filepath_array)} paths in local_filepath "
f"!= {len(remote_filepath_array)} paths in remote_filepath"
)

if not (self.operation.lower() == SFTPOperation.GET or self.operation.lower() == SFTPOperation.PUT):
Expand Down Expand Up @@ -145,7 +148,6 @@ def __init__(
)
self.sftp_hook = SFTPHook(ssh_hook=self.ssh_hook)

def execute(self, context: Any) -> str | list[str] | None:
file_msg = None
try:
if self.ssh_conn_id:
Expand All @@ -168,23 +170,23 @@ def execute(self, context: Any) -> str | list[str] | None:
)
self.sftp_hook.remote_host = self.remote_host

for local_filepath, remote_filepath in zip(self.local_filepath, self.remote_filepath):
for _local_filepath, _remote_filepath in zip(local_filepath_array, remote_filepath_array):
if self.operation.lower() == SFTPOperation.GET:
local_folder = os.path.dirname(local_filepath)
local_folder = os.path.dirname(_local_filepath)
if self.create_intermediate_dirs:
Path(local_folder).mkdir(parents=True, exist_ok=True)
file_msg = f"from {remote_filepath} to {local_filepath}"
file_msg = f"from {_remote_filepath} to {_local_filepath}"
self.log.info("Starting to transfer %s", file_msg)
self.sftp_hook.retrieve_file(remote_filepath, local_filepath)
self.sftp_hook.retrieve_file(_remote_filepath, _local_filepath)
else:
remote_folder = os.path.dirname(remote_filepath)
remote_folder = os.path.dirname(_remote_filepath)
if self.create_intermediate_dirs:
self.sftp_hook.create_directory(remote_folder)
file_msg = f"from {local_filepath} to {remote_filepath}"
file_msg = f"from {_local_filepath} to {_remote_filepath}"
self.log.info("Starting to transfer file %s", file_msg)
self.sftp_hook.store_file(remote_filepath, local_filepath, confirm=self.confirm)
self.sftp_hook.store_file(_remote_filepath, _local_filepath, confirm=self.confirm)

except Exception as e:
raise AirflowException(f"Error while transferring {file_msg}, error: {str(e)}")

return self.local_filepath[0] if self.local_filepath_was_str else self.local_filepath
return local_filepath_array[0] if local_filepath_was_str else local_filepath_array
35 changes: 28 additions & 7 deletions tests/providers/sftp/operators/test_sftp.py
Expand Up @@ -396,16 +396,22 @@ def test_unequal_local_remote_file_paths(self):
task_id="test_sftp_unequal_paths",
local_filepath="/tmp/test",
remote_filepath=["/tmp/test1", "/tmp/test2"],
)
).execute(None)

def test_str_filepaths_converted_to_lists(self):
@mock.patch("airflow.providers.sftp.operators.sftp.SFTPHook.retrieve_file")
def test_str_filepaths_get(self, mock_get):
local_filepath = "/tmp/test"
remote_filepath = "/tmp/remotetest"
sftp_op = SFTPOperator(
task_id="test_str_to_list", local_filepath=local_filepath, remote_filepath=remote_filepath
)
assert sftp_op.local_filepath == [local_filepath]
assert sftp_op.remote_filepath == [remote_filepath]
SFTPOperator(
task_id="test_str_to_list",
sftp_hook=self.sftp_hook,
local_filepath=local_filepath,
remote_filepath=remote_filepath,
operation=SFTPOperation.GET,
).execute(None)
assert mock_get.call_count == 1
args, _ = mock_get.call_args_list[0]
assert args == (remote_filepath, local_filepath)

@mock.patch("airflow.providers.sftp.operators.sftp.SFTPHook.retrieve_file")
def test_multiple_paths_get(self, mock_get):
Expand All @@ -425,6 +431,21 @@ def test_multiple_paths_get(self, mock_get):
assert args0 == (remote_filepath[0], local_filepath[0])
assert args1 == (remote_filepath[1], local_filepath[1])

@mock.patch("airflow.providers.sftp.operators.sftp.SFTPHook.store_file")
def test_str_filepaths_put(self, mock_get):
local_filepath = "/tmp/test"
remote_filepath = "/tmp/remotetest"
SFTPOperator(
task_id="test_str_to_list",
sftp_hook=self.sftp_hook,
local_filepath=local_filepath,
remote_filepath=remote_filepath,
operation=SFTPOperation.PUT,
).execute(None)
assert mock_get.call_count == 1
args, _ = mock_get.call_args_list[0]
assert args == (remote_filepath, local_filepath)

@mock.patch("airflow.providers.sftp.operators.sftp.SFTPHook.store_file")
def test_multiple_paths_put(self, mock_put):
local_filepath = ["/tmp/ltest1", "/tmp/ltest2"]
Expand Down

0 comments on commit bac7b30

Please sign in to comment.