diff --git a/pssh/ssh_client.py b/pssh/ssh_client.py index c6e04f33..ffdeb773 100644 --- a/pssh/ssh_client.py +++ b/pssh/ssh_client.py @@ -290,8 +290,9 @@ def _mkdir(self, sftp, directory): try: sftp.mkdir(directory) except IOError, error: - logger.error("Error occured creating directory %s on %s - %s", - directory, self.host, error) + msg = "Error occured creating directory %s on %s - %s" + logger.error(msg, directory, self.host, error) + raise IOError(msg, directory, self.host, error) logger.debug("Creating remote directory %s", directory) return True @@ -372,6 +373,6 @@ def copy_file(self, local_file, remote_file, recurse=False): except Exception, error: logger.error("Error occured copying file %s to remote destination %s:%s - %s", local_file, self.host, remote_file, error) - else: - logger.info("Copied local file %s to remote destination %s:%s", - local_file, self.host, remote_file) + raise error + logger.info("Copied local file %s to remote destination %s:%s", + local_file, self.host, remote_file) diff --git a/requirements.txt b/requirements.txt index 7420037d..afc628bb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,4 @@ +setuptools>=21.0 paramiko>=1.12,!=1.16.0 -gevent>=1.1rc3 +gevent<=1.1; python_version < '3' +gevent>=1.1; python_version >= '3' diff --git a/setup.py b/setup.py index 451920fd..b434f16e 100644 --- a/setup.py +++ b/setup.py @@ -29,6 +29,9 @@ packages=find_packages('.', exclude=( 'embedded_server', 'embedded_server.*')), install_requires=['paramiko', 'gevent'], + extras_require={':python_version < "3"': ['gevent<=1.1'], + ':python_version >= "3"': ['gevent>=1.1'], + }, classifiers=[ 'License :: OSI Approved :: GNU Lesser General Public License v2 (LGPLv2)', 'Intended Audience :: Developers', diff --git a/tests/test_pssh_client.py b/tests/test_pssh_client.py index 49894291..63fa8b8d 100644 --- a/tests/test_pssh_client.py +++ b/tests/test_pssh_client.py @@ -33,6 +33,8 @@ import os import warnings import shutil +import sys + PKEY_FILENAME = os.path.sep.join([os.path.dirname(__file__), 'test_client_private_key']) USER_KEY = paramiko.RSAKey.from_private_key_file(PKEY_FILENAME) @@ -390,6 +392,60 @@ def test_pssh_client_directory(self): self.assertTrue(os.path.isfile(path)) shutil.rmtree(local_test_path) shutil.rmtree(remote_test_path) + + def test_pssh_client_copy_file_failure(self): + """Test failure scenarios of file copy""" + test_file_data = 'test' + local_test_path = 'directory_test' + remote_test_path = 'directory_test_copied' + for path in [local_test_path, remote_test_path]: + mask = int('0700') if sys.version_info <= (2,) else 0o700 + if os.path.isdir(path): + os.chmod(path, mask) + for root, dirs, files in os.walk(path): + os.chmod(root, mask) + for _path in files + dirs: + os.chmod(os.path.join(root, _path), mask) + try: + shutil.rmtree(path) + except OSError: + pass + os.mkdir(local_test_path) + os.mkdir(remote_test_path) + local_file_path = os.path.join(local_test_path, 'test_file') + remote_file_path = os.path.join(remote_test_path, 'test_file') + test_file = open(local_file_path, 'w') + test_file.write('testing\n') + test_file.close() + # Permission errors on writing into dir + mask = 0111 if sys.version_info <= (2,) else 0o111 + os.chmod(remote_test_path, mask) + client = ParallelSSHClient([self.host], port=self.listen_port, + pkey=self.user_key) + cmds = client.copy_file(local_test_path, remote_test_path, recurse=True) + for cmd in cmds: + try: + cmd.get() + raise Exception("Expected IOError exception, got none") + except IOError: + pass + self.assertFalse(os.path.isfile(remote_file_path)) + # Create directory tree failure test + local_file_path = os.path.join(local_test_path, 'test_file') + remote_file_path = os.path.join(remote_test_path, 'test_dir', 'test_file') + cmds = client.copy_file(local_file_path, remote_file_path, recurse=True) + for cmd in cmds: + try: + cmd.get() + raise Exception("Expected IOError exception on creating remote " + "directory, got none") + except IOError: + pass + self.assertFalse(os.path.isfile(remote_file_path)) + mask = int('0600') if sys.version_info <= (2,) else 0o600 + os.chmod(remote_test_path, mask) + for path in [local_test_path, remote_test_path]: + shutil.rmtree(path) def test_pssh_pool_size(self): """Test setting pool size to non default values"""