Skip to content

Commit

Permalink
Close open connections for deferrable SFTPSensor (#38881)
Browse files Browse the repository at this point in the history
Co-authored-by: Wei Lee <weilee.rx@gmail.com>
  • Loading branch information
pankajkoti and Lee-W committed Apr 10, 2024
1 parent 0af5d92 commit d703a22
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 21 deletions.
33 changes: 16 additions & 17 deletions airflow/providers/sftp/hooks/sftp.py
Expand Up @@ -516,25 +516,24 @@ async def _get_conn(self) -> asyncssh.SSHClientConnection:
ssh_client_conn = await asyncssh.connect(**conn_config)
return ssh_client_conn

async def list_directory(self, path: str = "") -> list[str] | None:
async def list_directory(self, path: str = "") -> list[str] | None: # type: ignore[return]
"""Return a list of files on the SFTP server at the provided path."""
ssh_conn = await self._get_conn()
sftp_client = await ssh_conn.start_sftp_client()
try:
files = await sftp_client.listdir(path)
return sorted(files)
except asyncssh.SFTPNoSuchFile:
return None

async def read_directory(self, path: str = "") -> Sequence[asyncssh.sftp.SFTPName] | None:
async with await self._get_conn() as ssh_conn:
sftp_client = await ssh_conn.start_sftp_client()
try:
files = await sftp_client.listdir(path)
return sorted(files)
except asyncssh.SFTPNoSuchFile:
return None

async def read_directory(self, path: str = "") -> Sequence[asyncssh.sftp.SFTPName] | None: # type: ignore[return]
"""Return a list of files along with their attributes on the SFTP server at the provided path."""
ssh_conn = await self._get_conn()
sftp_client = await ssh_conn.start_sftp_client()
try:
files = await sftp_client.readdir(path)
return files
except asyncssh.SFTPNoSuchFile:
return None
async with await self._get_conn() as ssh_conn:
sftp_client = await ssh_conn.start_sftp_client()
try:
return await sftp_client.readdir(path)
except asyncssh.SFTPNoSuchFile:
return None

async def get_files_and_attrs_by_pattern(
self, path: str = "", fnmatch_pattern: str = ""
Expand Down
13 changes: 9 additions & 4 deletions tests/providers/sftp/hooks/test_sftp.py
Expand Up @@ -712,52 +712,57 @@ async def test_list_directory_path_does_not_exist(self, mock_hook_get_conn):
"""
Assert that AirflowException is raised when path does not exist on SFTP server
"""
mock_hook_get_conn.return_value = MockSSHClient()
mock_hook_get_conn.return_value.__aenter__.return_value = MockSSHClient()

hook = SFTPHookAsync()

expected_files = None
files = await hook.list_directory(path="/path/does_not/exist/")
assert files == expected_files
mock_hook_get_conn.return_value.__aexit__.assert_called()

@patch("airflow.providers.sftp.hooks.sftp.SFTPHookAsync._get_conn")
@pytest.mark.asyncio
async def test_read_directory_path_does_not_exist(self, mock_hook_get_conn):
"""
Assert that AirflowException is raised when path does not exist on SFTP server
"""
mock_hook_get_conn.return_value = MockSSHClient()
mock_hook_get_conn.return_value.__aenter__.return_value = MockSSHClient()
hook = SFTPHookAsync()

expected_files = None
files = await hook.read_directory(path="/path/does_not/exist/")
assert files == expected_files
mock_hook_get_conn.return_value.__aexit__.assert_called()

@patch("airflow.providers.sftp.hooks.sftp.SFTPHookAsync._get_conn")
@pytest.mark.asyncio
async def test_list_directory_path_has_files(self, mock_hook_get_conn):
"""
Assert that file list is returned when path exists on SFTP server
"""
mock_hook_get_conn.return_value = MockSSHClient()
mock_hook_get_conn.return_value.__aenter__.return_value = MockSSHClient()
hook = SFTPHookAsync()

expected_files = ["..", ".", "file"]
files = await hook.list_directory(path="/path/exists/")
assert sorted(files) == sorted(expected_files)
mock_hook_get_conn.return_value.__aexit__.assert_called()

@patch("airflow.providers.sftp.hooks.sftp.SFTPHookAsync._get_conn")
@pytest.mark.asyncio
async def test_get_file_by_pattern_with_match(self, mock_hook_get_conn):
"""
Assert that filename is returned when file pattern is matched on SFTP server
"""
mock_hook_get_conn.return_value = MockSSHClient()
mock_hook_get_conn.return_value.__aenter__.return_value = MockSSHClient()
hook = SFTPHookAsync()

files = await hook.get_files_and_attrs_by_pattern(path="/path/exists/", fnmatch_pattern="file")

assert len(files) == 1
assert files[0].filename == "file"
mock_hook_get_conn.return_value.__aexit__.assert_called()

@pytest.mark.asyncio
@patch("airflow.providers.sftp.hooks.sftp.SFTPHookAsync._get_conn")
Expand Down

0 comments on commit d703a22

Please sign in to comment.