Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 57 additions & 2 deletions airflow/contrib/hooks/sftp_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import datetime
import stat
from typing import Dict, List
from typing import Dict, List, Optional, Tuple

import pysftp

Expand Down Expand Up @@ -223,7 +223,7 @@ def get_mod_time(self, path: str) -> str:
ftp_mdtm = conn.stat(path).st_mtime
return datetime.datetime.fromtimestamp(ftp_mdtm).strftime('%Y%m%d%H%M%S')

def path_exists(self, path):
def path_exists(self, path: str) -> bool:
"""
Returns True if a remote entity exists

Expand All @@ -232,3 +232,58 @@ def path_exists(self, path):
"""
conn = self.get_conn()
return conn.exists(path)

@staticmethod
def _is_path_match(path: str, prefix: Optional[str] = None, delimiter: Optional[str] = None) -> bool:
    """
    Return True if given path starts with prefix (if set) and ends with delimiter (if set).

    A filter left as ``None`` is not applied, so calling this with neither
    filter set returns True for any path.

    :param path: path to be checked
    :type path: str
    :param prefix: if set, the path must start with this string
    :type prefix: str
    :param delimiter: if set, the path must end with this string; it is compared
        with ``str.endswith``, i.e. treated as a plain suffix
    :type delimiter: str
    :return: True when the path satisfies every filter that is set, False otherwise
    :rtype: bool
    """
    # ``and`` short-circuits, so the delimiter is only looked at once the prefix check passed.
    return (prefix is None or path.startswith(prefix)) and (
        delimiter is None or path.endswith(delimiter)
    )

def get_tree_map(
    self, path: str, prefix: Optional[str] = None, delimiter: Optional[str] = None
) -> Tuple[List[str], List[str], List[str]]:
    """
    Return tuple with recursive lists of files, directories and unknown paths from given path.
    It is possible to filter results by giving prefix and/or delimiter parameters.

    The filters are applied to each path exactly as the connection's ``walktree``
    reports it to the callbacks, not to a single path component.

    :param path: path from which tree will be built
    :type path: str
    :param prefix: if set, only paths starting with prefix are added
    :type prefix: str
    :param delimiter: if set, only paths ending with delimiter are added
    :type delimiter: str
    :return: tuple with list of files, dirs and unknown items
    :rtype: Tuple[List[str], List[str], List[str]]
    """
    conn = self.get_conn()
    files = []  # type: List[str]
    dirs = []  # type: List[str]
    unknowns = []  # type: List[str]

    def collect_into(bucket):
        """Build a walktree callback that stores matching paths in ``bucket``."""

        def callback(item):
            if self._is_path_match(item, prefix, delimiter):
                bucket.append(item)

        return callback

    # One callback per entry kind, each bound to its own result list.
    conn.walktree(
        remotepath=path,
        fcallback=collect_into(files),
        dcallback=collect_into(dirs),
        ucallback=collect_into(unknowns),
        recurse=True,
    )

    return files, dirs, unknowns
37 changes: 33 additions & 4 deletions tests/contrib/hooks/test_sftp_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@

TMP_PATH = '/tmp'
TMP_DIR_FOR_TESTS = 'tests_sftp_hook_dir'
SUB_DIR = "sub_dir"
TMP_FILE_FOR_TESTS = 'test_file.txt'

SFTP_CONNECTION_USER = "root"
Expand All @@ -51,9 +52,12 @@ def update_connection(self, login, session=None):
def setUp(self):
self.old_login = self.update_connection(SFTP_CONNECTION_USER)
self.hook = SFTPHook()
os.makedirs(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
os.makedirs(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR))

with open(os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS), 'a') as file:
file.write('Test file')
with open(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR, TMP_FILE_FOR_TESTS), 'a') as file:
file.write('Test file')

def test_get_conn(self):
output = self.hook.get_conn()
Expand All @@ -72,7 +76,7 @@ def test_describe_directory(self):
def test_list_directory(self):
output = self.hook.list_directory(
path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
self.assertEqual(output, [])
self.assertEqual(output, [SUB_DIR])

def test_create_and_delete_directory(self):
new_dir_name = 'new_dir'
Expand Down Expand Up @@ -116,7 +120,7 @@ def test_store_retrieve_and_delete_file(self):
)
output = self.hook.list_directory(
path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
self.assertEqual(output, [TMP_FILE_FOR_TESTS])
self.assertEqual(output, [SUB_DIR, TMP_FILE_FOR_TESTS])
retrieved_file_name = 'retrieved.txt'
self.hook.retrieve_file(
remote_full_path=os.path.join(
Expand All @@ -129,7 +133,7 @@ def test_store_retrieve_and_delete_file(self):
TMP_PATH, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS))
output = self.hook.list_directory(
path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
self.assertEqual(output, [])
self.assertEqual(output, [SUB_DIR])

def test_get_mod_time(self):
self.hook.store_file(
Expand Down Expand Up @@ -208,6 +212,31 @@ def test_path_exists(self, path, exists):
result = self.hook.path_exists(path)
self.assertEqual(result, exists)

@parameterized.expand([
    ("test/path/file.bin", None, None, True),
    ("test/path/file.bin", "test", None, True),
    ("test/path/file.bin", "test/", None, True),
    ("test/path/file.bin", None, "bin", True),
    ("test/path/file.bin", "test", "bin", True),
    ("test/path/file.bin", "test/", "file.bin", True),
    ("test/path/file.bin", None, "file.bin", True),
    ("test/path/file.bin", "diff", None, False),
    ("test/path/file.bin", "test//", None, False),
    ("test/path/file.bin", None, ".txt", False),
    ("test/path/file.bin", "diff", ".txt", False),
])
def test_path_match(self, path, prefix, delimiter, match):
    """Each row is (path, prefix, delimiter, expected result of ``_is_path_match``)."""
    self.assertEqual(
        self.hook._is_path_match(path=path, prefix=prefix, delimiter=delimiter),
        match,
    )

def test_get_tree_map(self):
    """Walk the test directory and expect only the sub directory and the file inside it."""
    root = os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS)
    sub_dir = os.path.join(root, SUB_DIR)

    files, dirs, unknowns = self.hook.get_tree_map(path=root)

    self.assertEqual(files, [os.path.join(sub_dir, TMP_FILE_FOR_TESTS)])
    self.assertEqual(dirs, [sub_dir])
    self.assertEqual(unknowns, [])

def tearDown(self):
shutil.rmtree(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
os.remove(os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS))
Expand Down