/
ssh.py
276 lines (217 loc) · 9.25 KB
/
ssh.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
import errno
import logging
import os
import stat

import paramiko

from parsl.channels.base import Channel
from parsl.channels.errors import BadHostKeyException, AuthException, SSHException, BadScriptPath, BadPermsScriptPath, FileCopyException
from parsl.utils import RepresentationMixin
# Module-scoped logger; SSHChannel reports connection and file-transfer problems through it.
logger = logging.getLogger(name=__name__)
class NoAuthSSHClient(paramiko.SSHClient):
    """A paramiko SSHClient that authenticates using only the SSH "none" method.

    It overrides the client's private ``_auth`` hook, so no password, key or
    agent authentication is attempted.
    """

    def _auth(self, username, *args):
        # Every remaining positional auth argument (password, keys, agent flags, ...)
        # is deliberately ignored.
        transport = self._transport
        transport.auth_none(username)
class SSHChannel(Channel, RepresentationMixin):
    ''' SSH persistent channel. This enables remote execution on sites
    accessible via ssh. It assumes the user has already set up keys (or
    another authentication method) so that the following works from the
    command line::

        ssh <username>@<hostname>

    The SSH connection, and an SFTP session over it, are opened lazily on
    first use and re-opened whenever the transport is found to be inactive.
    '''

    def __init__(self, hostname, username=None, password=None, script_dir=None, envs=None,
                 gssapi_auth=False, skip_auth=False, port=22, key_filename=None, host_keys_filename=None):
        ''' Initialize a persistent channel to the remote system.

        No network connection is made here; it is established on first use
        (see ``_connect``).

        Args:
            - hostname (String) : Hostname

        KWargs:
            - username (string) : Username on remote system
            - password (string) : Password for remote system
            - script_dir (string) : Full path to a script dir where
              generated scripts could be sent to.
            - envs (dict) : A dictionary of environment variables to be set when executing commands
            - gssapi_auth (bool) : Use GSS-API for authentication and key exchange. Default is False.
            - skip_auth (bool) : Authenticate with the SSH "none" method only. Default is False.
            - port : The port designated for the ssh connection. Default is 22.
            - key_filename (string or list): the filename, or list of filenames, of optional private key(s)
            - host_keys_filename (string) : known-hosts file passed to paramiko's
              ``load_system_host_keys``; None selects paramiko's default.
        '''
        self.hostname = hostname
        self.username = username
        self.password = password
        self.port = port
        self.script_dir = script_dir
        self.skip_auth = skip_auth
        self.gssapi_auth = gssapi_auth
        self.key_filename = key_filename
        self.host_keys_filename = host_keys_filename

        if self.skip_auth:
            self.ssh_client = NoAuthSSHClient()
        else:
            self.ssh_client = paramiko.SSHClient()
        self.ssh_client.load_system_host_keys(filename=host_keys_filename)
        # NOTE(review): AutoAddPolicy accepts any host key that is not already known,
        # so a first connection is not protected against impersonation. Kept as-is
        # for backward compatibility.
        self.ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())

        self.sftp_client = None
        self.envs = {}
        if envs is not None:
            self.envs = envs

    def _is_connected(self):
        """Return True if the SSH client has a transport and it is active."""
        transport = self.ssh_client.get_transport() if self.ssh_client else None
        return bool(transport and transport.is_active())

    def _connect(self):
        """Open the SSH connection and an SFTP session unless already connected.

        Raises:
            - BadHostKeyException : the server's host key did not match the known key
            - AuthException : authentication failed
            - SSHException : any other failure while connecting or opening SFTP
        """
        if self._is_connected():
            return
        logger.debug(f"connecting to {self.hostname}:{self.port}")
        # Forget any SFTP client bound to a previous (now inactive) transport, so a
        # failure below cannot leave a stale client behind.
        self.sftp_client = None
        try:
            self.ssh_client.connect(
                self.hostname,
                username=self.username,
                password=self.password,
                port=self.port,
                allow_agent=True,
                gss_auth=self.gssapi_auth,
                gss_kex=self.gssapi_auth,
                key_filename=self.key_filename
            )
            transport = self.ssh_client.get_transport()
            self.sftp_client = paramiko.SFTPClient.from_transport(transport)
        except Exception as e:
            # Tear down a half-open connection (e.g. transport up but auth or SFTP
            # failed). Otherwise the next call would see an active transport, skip
            # reconnecting, and hand out a None SFTP client.
            self.ssh_client.close()
            if isinstance(e, paramiko.BadHostKeyException):
                raise BadHostKeyException(e, self.hostname) from e
            if isinstance(e, paramiko.AuthenticationException):
                raise AuthException(e, self.hostname) from e
            raise SSHException(e, self.hostname) from e

    def _valid_sftp_client(self):
        """Return the SFTP client, connecting first if necessary."""
        self._connect()
        return self.sftp_client

    def _valid_ssh_client(self):
        """Return the SSH client, connecting first if necessary."""
        self._connect()
        return self.ssh_client

    def prepend_envs(self, cmd, env=None):
        """Prefix ``cmd`` with ``env K=V ...`` for per-call and channel-wide variables.

        Args:
            - cmd (string) : command line to run
            - env (dict) : per-call environment variables; the dict is not modified.
              Entries in the channel-wide ``self.envs`` win on key clashes.

        Returns:
            - string : ``cmd`` unchanged when there are no variables,
              otherwise ``'env K1=V1 K2=V2 ... ' + cmd``
        """
        # Work on a copy: updating the caller's dict (or a shared default) would
        # leak this channel's envs into other calls and other channels.
        merged = dict(env) if env else {}
        merged.update(self.envs)
        if not merged:
            return cmd
        # Values are inserted verbatim (not quoted); any shell syntax in them is
        # left for the remote side to interpret.
        env_vars = ' '.join('{}={}'.format(key, value) for key, value in merged.items())
        return 'env {0} {1}'.format(env_vars, cmd)

    def execute_wait(self, cmd, walltime=2, envs=None):
        ''' Synchronously execute a commandline string on the shell.

        Args:
            - cmd (string) : Commandline string to execute
            - walltime (int) : passed as ``timeout`` to paramiko's ``exec_command``

        Kwargs:
            - envs (dict) : Dictionary of env variables (not modified)

        Returns:
            - retcode : Exit status reported by paramiko's ``recv_exit_status``
            - stdout : stdout string (decoded as UTF-8)
            - stderr : stderr string (decoded as UTF-8)

        Raises:
            - BadHostKeyException, AuthException, SSHException : if (re)connecting fails
            - paramiko / socket errors from running the command are not caught here
        '''
        # Execute the command
        stdin, stdout, stderr = self._valid_ssh_client().exec_command(
            self.prepend_envs(cmd, envs), bufsize=-1, timeout=walltime
        )
        # Block on exit status from the command.
        # NOTE(review): paramiko documents that recv_exit_status() can hang when the
        # command emits more output than the channel window and nothing reads it;
        # here the streams are only read after the exit status arrives.
        exit_status = stdout.channel.recv_exit_status()
        return exit_status, stdout.read().decode("utf-8"), stderr.read().decode("utf-8")

    def push_file(self, local_source, remote_dir):
        ''' Transport a local file to a directory on a remote machine

        ``remote_dir`` is created if needed (via ``makedirs``) and the copied
        file is chmod'ed to 0o700.

        Args:
            - local_source (string): Path
            - remote_dir (string): Remote path

        Returns:
            - str: Path to copied file on remote machine

        Raises:
            - BadScriptPath : remote_dir could not be created (ENOENT)
            - BadPermsScriptPath : You do not have perms to make the channel script dir (EACCES)
            - FileCopyException : FileCopy failed.
        '''
        remote_dest = os.path.join(remote_dir, os.path.basename(local_source))
        try:
            self.makedirs(remote_dir, exist_ok=True)
        except IOError as e:
            logger.exception("Pushing {0} to {1} failed".format(local_source, remote_dir))
            if e.errno == errno.ENOENT:
                raise BadScriptPath(e, self.hostname) from e
            elif e.errno == errno.EACCES:
                raise BadPermsScriptPath(e, self.hostname) from e
            else:
                logger.exception("File push failed due to SFTP client failure")
                raise FileCopyException(e, self.hostname) from e
        try:
            sftp = self._valid_sftp_client()
            sftp.put(local_source, remote_dest, confirm=True)
            # Set perm because some systems require the script to be executable
            sftp.chmod(remote_dest, 0o700)
        except Exception as e:
            logger.exception("File push from local source {} to remote destination {} failed".format(
                local_source, remote_dest))
            raise FileCopyException(e, self.hostname) from e
        return remote_dest

    def pull_file(self, remote_source, local_dir):
        ''' Transport file on the remote side to a local directory

        Args:
            - remote_source (string): remote_source
            - local_dir (string): Local directory to copy to (created if missing)

        Returns:
            - str: Local path to file

        Raises:
            - BadScriptPath : local_dir could not be created
            - FileCopyException : FileCopy failed.
        '''
        local_dest = os.path.join(local_dir, os.path.basename(remote_source))
        try:
            os.makedirs(local_dir)
        except OSError as e:
            # An already-existing local_dir is fine; any other error is fatal.
            if e.errno != errno.EEXIST:
                logger.exception("Failed to create local_dir: {0}".format(local_dir))
                raise BadScriptPath(e, self.hostname) from e
        try:
            self._valid_sftp_client().get(remote_source, local_dest)
        except Exception as e:
            logger.exception("File pull failed")
            raise FileCopyException(e, self.hostname) from e
        return local_dest

    def close(self):
        """Close the SSH client if its transport is currently active."""
        if self._is_connected():
            return self.ssh_client.close()

    def isdir(self, path):
        """Return true if the path refers to an existing directory.

        Symlinks are followed (SFTP ``stat``), matching ``os.path.isdir``.

        Parameters
        ----------
        path : str
            Path of directory on the remote side to check.
        """
        try:
            attrs = self._valid_sftp_client().stat(path)
        except FileNotFoundError:
            return False
        if attrs.st_mode is None:
            # The server did not report a file type; all we know is that the path exists.
            return True
        return stat.S_ISDIR(attrs.st_mode)

    def makedirs(self, path, mode=0o700, exist_ok=False):
        """Create a directory on the remote side.

        If intermediate directories do not exist, they will be created. Only
        the final directory is chmod'ed to ``mode``; intermediate directories
        keep whatever ``mkdir -p`` gave them.

        Parameters
        ----------
        path : str
            Path of directory on the remote side to create.
        mode : int
            Permissions (posix-style) for the newly-created directory.
        exist_ok : bool
            If False, raise an OSError if the target directory already exists.

        Raises
        ------
        OSError
            If the directory exists and ``exist_ok`` is False; with errno ENOENT
            if ``mkdir -p`` failed and ``path`` is not a directory; or from the
            SFTP chmod.
        """
        if exist_ok is False and self.isdir(path):
            raise OSError('Target directory {} already exists'.format(path))
        # NOTE(review): path is interpolated unquoted into a shell command.
        retcode, _, stderr = self.execute_wait('mkdir -p {}'.format(path))
        # Re-check before failing, in case the exit status is unreliable but the
        # directory is in fact there.
        if retcode != 0 and not self.isdir(path):
            # ENOENT makes push_file report this as BadScriptPath.
            raise OSError(errno.ENOENT, 'mkdir -p failed: {}'.format(stderr.strip()), path)
        self._valid_sftp_client().chmod(path, mode)

    def abspath(self, path):
        """Return the absolute path on the remote side.

        Parameters
        ----------
        path : str
            Path for which the absolute path will be returned.
        """
        return self._valid_sftp_client().normalize(path)

    @property
    def script_dir(self):
        """Remote directory where generated scripts are placed (may be None)."""
        return self._script_dir

    @script_dir.setter
    def script_dir(self, value):
        self._script_dir = value