diff --git a/pssh/ssh_client.py b/pssh/ssh_client.py index 3c22e338..ad271ffe 100644 --- a/pssh/ssh_client.py +++ b/pssh/ssh_client.py @@ -285,7 +285,16 @@ def mkdir(self, sftp, directory): return self.mkdir(sftp, sub_dirs) return True - def copy_file(self, local_file, remote_file): + def _copy_dir(self, local_dir, remote_dir): + """Call copy_file on every file in the specified directory, copying + them to the specified remote directory.""" + file_list = os.listdir(local_dir) + for file_name in file_list: + local_path = os.path.join(local_dir, file_name) + remote_path = os.path.join(remote_dir, file_name) + self.copy_file(local_path, remote_path, recurse=True) + + def copy_file(self, local_file, remote_file, recurse=False): """Copy local file to host via SFTP/SCP Copy is done natively using SFTP/SCP version 2 protocol, no scp command \ @@ -295,7 +304,17 @@ def copy_file(self, local_file, remote_file): :type local_file: str :param remote_file: Remote filepath on remote host to copy file to :type remote_file: str + :param recurse: Whether or not to descend into directories recursively. + :type recurse: bool + + :raises: :mod:'ValueError' when a directory is supplied to local_file \ + and recurse is not set """ + if os.path.isdir(local_file) and recurse: + return self._copy_dir(local_file, remote_file) + elif os.path.isdir(local_file) and not recurse: + raise ValueError("Recurse must be true if local_file is a " + "directory.") sftp = self._make_sftp() destination = [_dir for _dir in remote_file.split(os.path.sep) if _dir][:-1][0] diff --git a/tests/test_ssh_client.py b/tests/test_ssh_client.py index 6f5e5977..7db46b87 100644 --- a/tests/test_ssh_client.py +++ b/tests/test_ssh_client.py @@ -23,6 +23,7 @@ import gevent import socket import time +import shutil import unittest from pssh import SSHClient, ParallelSSHClient, UnknownHostException, AuthenticationException,\ logger, ConnectionErrorException, UnknownHostException, SSHException @@ -141,6 +142,49 @@ def test_ssh_client_sftp(self): os.rmdir(dirpath) del client + def test_ssh_client_directory(self): + """Tests copying directories with SSH client. Copy all the files from + local directory to server, then make sure they are all present.""" + test_file_data = 'test' + local_test_path = 'directory_test' + remote_test_path = 'directory_test_copied' + os.mkdir(local_test_path) + remote_file_paths = [] + for i in range(0, 10): + local_file_path = os.path.join(local_test_path, 'foo' + str(i)) + remote_file_path = os.path.join(remote_test_path, 'foo' + str(i)) + remote_file_paths.append(remote_file_path) + test_file = open(local_file_path, 'w') + test_file.write(test_file_data) + test_file.close() + client = SSHClient(self.host, port=self.listen_port, + pkey=self.user_key) + client.copy_file(local_test_path, remote_test_path, recurse=True) + for path in remote_file_paths: + self.assertTrue(os.path.isfile(path)) + shutil.rmtree(local_test_path) + shutil.rmtree(remote_test_path) + + def test_ssh_client_directory_no_recurse(self): + """Tests copying directories with SSH client. Copy all the files from + local directory to server, then make sure they are all present.""" + test_file_data = 'test' + local_test_path = 'directory_test' + remote_test_path = 'directory_test_copied' + os.mkdir(local_test_path) + remote_file_paths = [] + for i in range(0, 10): + local_file_path = os.path.join(local_test_path, 'foo' + str(i)) + remote_file_path = os.path.join(remote_test_path, 'foo' + str(i)) + remote_file_paths.append(remote_file_path) + test_file = open(local_file_path, 'w') + test_file.write(test_file_data) + test_file.close() + client = SSHClient(self.host, port=self.listen_port, + pkey=self.user_key) + self.assertRaises(ValueError, client.copy_file, local_test_path, remote_test_path) + shutil.rmtree(local_test_path) + def test_ssh_agent_authentication(self): """Test authentication via SSH agent. Do not provide public key to use when creating SSHClient,