diff --git a/pssh/ssh_client.py b/pssh/ssh_client.py index fc6fd9a3..3c22e338 100644 --- a/pssh/ssh_client.py +++ b/pssh/ssh_client.py @@ -239,7 +239,7 @@ def _make_sftp(self): transport.open_session() return paramiko.SFTPClient.from_transport(transport) - def mkdir(self, sftp, directory): + def _mkdir(self, sftp, directory): """Make directory via SFTP channel :param sftp: SFTP client object @@ -249,19 +249,41 @@ def mkdir(self, sftp, directory): Catches and logs at error level remote IOErrors on creating directory. """ - sub_dirs = [_dir for _dir in directory.split(os.path.sep) if _dir][:-1] - sub_dirs = os.path.sep + os.path.sep.join(sub_dirs) if directory.startswith(os.path.sep) \ - else os.path.sep.join(sub_dirs) - if sub_dirs: - try: - sftp.stat(sub_dirs) - except IOError: - return self.mkdir(sftp, sub_dirs) try: sftp.mkdir(directory) except IOError, error: logger.error("Error occured creating directory %s on %s - %s", directory, self.host, error) + logger.debug("Creating remote directory %s", directory) + return True + + def mkdir(self, sftp, directory): + """Make directory via SFTP channel. + + Parent paths in the directory are created if they do not exist. + + :param sftp: SFTP client object + :type sftp: :mod:`paramiko.SFTPClient` + :param directory: Remote directory to create + :type directory: str + + Catches and logs at error level remote IOErrors on creating directory. + """ + try: + parent_path, sub_dirs = directory.split(os.path.sep, 1) + except ValueError: + parent_path = directory.split(os.path.sep, 1)[0] + sub_dirs = None + if not parent_path and directory.startswith(os.path.sep): + parent_path, sub_dirs = sub_dirs.split(os.path.sep, 1) + try: + sftp.stat(parent_path) + except IOError: + self._mkdir(sftp, parent_path) + sftp.chdir(parent_path) + if sub_dirs: + return self.mkdir(sftp, sub_dirs) + return True def copy_file(self, local_file, remote_file): """Copy local file to host via SFTP/SCP @@ -276,14 +298,14 @@ def copy_file(self, local_file, remote_file): """ sftp = self._make_sftp() destination = [_dir for _dir in remote_file.split(os.path.sep) - if _dir][:-1] + if _dir][:-1][0] if remote_file.startswith(os.path.sep): - destination[0] = os.path.sep + destination[0] - # import ipdb; ipdb.set_trace() + destination = os.path.sep + destination try: sftp.stat(destination) except IOError: self.mkdir(sftp, destination) + sftp.chdir() try: sftp.put(local_file, remote_file) except Exception, error: diff --git a/tests/test_ssh_client.py b/tests/test_ssh_client.py index b6e7b96a..6f5e5977 100644 --- a/tests/test_ssh_client.py +++ b/tests/test_ssh_client.py @@ -33,6 +33,7 @@ import os from test_pssh_client import USER_KEY import random, string +import shutil USER_KEY = paramiko.RSAKey.from_private_key_file( os.path.sep.join([os.path.dirname(__file__), 'test_client_private_key'])) @@ -51,6 +52,61 @@ def setUp(self): def tearDown(self): del self.server del self.listen_socket + + def test_ssh_client_mkdir_recursive(self): + """Test SFTP mkdir of SSHClient""" + base_path = 'remote_test_dir1' + remote_dir = os.path.sep.join([base_path, + 'remote_test_dir2', + 'remote_test_dir3']) + try: + shutil.rmtree(base_path) + except OSError: + pass + client = SSHClient(self.host, port=self.listen_port, + pkey=self.user_key) + client.mkdir(client._make_sftp(), remote_dir) + self.assertTrue(os.path.isdir(remote_dir), + msg="SFTP recursive mkdir failed") + shutil.rmtree(base_path) + del client + + def test_ssh_client_mkdir_recursive_abspath(self): + """Test SFTP mkdir of SSHClient with absolute path + + Absolute SFTP paths resolve under the users' home directory, + not the root filesystem + """ + base_path = 'tmp' + remote_dir = os.path.sep.join([base_path, + 'remote_test_dir2', + 'remote_test_dir3']) + try: + shutil.rmtree(base_path) + except OSError: + pass + client = SSHClient(self.host, port=self.listen_port, + pkey=self.user_key) + client.mkdir(client._make_sftp(), '/' + remote_dir) + self.assertTrue(os.path.isdir(remote_dir), + msg="SFTP recursive mkdir failed") + shutil.rmtree(base_path) + del client + + def test_ssh_client_mkdir_single(self): + """Test SFTP mkdir of SSHClient""" + remote_dir = 'remote_test_dir1' + try: + shutil.rmtree(remote_dir) + except OSError: + pass + client = SSHClient(self.host, port=self.listen_port, + pkey=self.user_key) + client.mkdir(client._make_sftp(), remote_dir) + self.assertTrue(os.path.isdir(remote_dir), + msg="SFTP recursive mkdir failed") + shutil.rmtree(remote_dir) + del client def test_ssh_client_sftp(self): """Test SFTP features of SSHClient. Copy local filename to server,