From cb3e1a599424fbafb4ceaa844cdcb330d58c6ed4 Mon Sep 17 00:00:00 2001 From: Dan <22e889d8@opayq.com> Date: Tue, 15 Nov 2016 23:47:26 +0000 Subject: [PATCH 1/8] Merging with server refactoring changes --- embedded_server/embedded_server.py | 68 ++++++++++++++++++++---------- examples/parallel_commands.py | 4 +- pssh/ssh_client.py | 1 + 3 files changed, 49 insertions(+), 24 deletions(-) diff --git a/embedded_server/embedded_server.py b/embedded_server/embedded_server.py index 254ec784..cc03d098 100644 --- a/embedded_server/embedded_server.py +++ b/embedded_server/embedded_server.py @@ -30,7 +30,7 @@ Does _not_ support interactive shells, our clients do not use them. -Server private key is hardcoded. Server listen code inspired by demo_server.py in \ +Server private key is hardcoded. Server listen code inspired by demo_server.py in Paramiko repository. Server runs asynchronously in its own greenlet. Call `start_server` with a new `multiprocessing.Process` to run it on a new process with its own event loop. @@ -62,12 +62,34 @@ os.path.dirname(__file__), 'rsa.key'])) class Server(paramiko.ServerInterface): - def __init__(self, transport, fail_auth=False, + """Implements :mod:`paramiko.ServerInterface` to provide an + embedded SSH server implementation. + + Start a `Server` with at least a transport and a host key. + + Any SSH2 client with public key or password authentication + is allowed, only. Shell requests are not accepted. + + Implemented: + * Direct tcp-ip channels (tunneling) + * SSH Agent forwarding on request + * PTY requests + * Exec requests (run a command on server) + + Not Implemented: + * Shell requests + """ + + def __init__(self, transport, host_key, fail_auth=False, ssh_exception=False): self.event = Event() self.fail_auth = fail_auth self.ssh_exception = ssh_exception self.transport = transport + self.host_key = host_key + transport.load_server_moduli() + transport.add_server_key(self.host_key) + transport.set_subsystem_handler('sftp', paramiko.SFTPServer, StubSFTPServer) def check_channel_request(self, kind, chanid): return paramiko.OPEN_SUCCEEDED @@ -175,33 +197,29 @@ def listen(sock, fail_auth=False, ssh_exception=False, def _handle_ssh_connection(transport, fail_auth=False, ssh_exception=False): - try: - transport.load_server_moduli() - except: - return - transport.add_server_key(host_key) - transport.set_subsystem_handler('sftp', paramiko.SFTPServer, StubSFTPServer) - server = Server(transport, + server = Server(transport, HOST_KEY, fail_auth=fail_auth, ssh_exception=ssh_exception) + # server.run() try: transport.start_server(server=server) except paramiko.SSHException as e: logger.exception('SSH negotiation failed') - return except Exception: logger.exception("Error occured starting server") return - gevent.sleep(0) - channel = transport.accept(20) - if not channel: - logger.error("Could not establish channel") - return - while transport.is_active(): - logger.debug("Transport active, waiting..") - gevent.sleep(1) - while not channel.send_ready(): - gevent.sleep(.2) - channel.close() + while True: + gevent.sleep(0) + channel = transport.accept(20) + if not channel: + logger.error("Could not establish channel") + return + while transport.is_active(): + logger.debug("Transport active, waiting..") + gevent.sleep(1) + while not channel.send_ready(): + gevent.sleep(.2) + channel.close() + gevent.sleep(0) def handle_ssh_connection(sock, fail_auth=False, ssh_exception=False, @@ -227,8 +245,12 @@ def handle_ssh_connection(sock, def start_server(sock, fail_auth=False, ssh_exception=False, timeout=None): - return gevent.spawn(listen, sock, fail_auth=fail_auth, - timeout=timeout, ssh_exception=ssh_exception) + g = gevent.spawn(listen, sock, fail_auth=fail_auth, + timeout=timeout, ssh_exception=ssh_exception) + try: + g.join() + except KeyboardInterrupt: + sys.exit(0) if __name__ == "__main__": logging.basicConfig() diff --git a/examples/parallel_commands.py b/examples/parallel_commands.py index 98a76e58..509c1cd7 100644 --- a/examples/parallel_commands.py +++ b/examples/parallel_commands.py @@ -2,9 +2,10 @@ import datetime output = [] -host = 'localhost' +host = '192.168.1.2' hosts = [host] client = ParallelSSHClient(hosts) +import ipdb; ipdb.set_trace() # Run 10 five second sleeps cmds = ['sleep 5' for _ in xrange(10)] @@ -17,5 +18,6 @@ start = datetime.datetime.now() for _output in output: client.join(_output) + print(_output) end = datetime.datetime.now() print("All commands finished in %s" % (end-start,)) diff --git a/pssh/ssh_client.py b/pssh/ssh_client.py index fa653638..565087cb 100644 --- a/pssh/ssh_client.py +++ b/pssh/ssh_client.py @@ -120,6 +120,7 @@ def __init__(self, host, self.proxy_pkey = proxy_host, proxy_port, proxy_user, \ proxy_password, proxy_pkey self.proxy_client = None + import ipdb; ipdb.set_trace() if self.proxy_host and self.proxy_port: logger.debug("Proxy configured for destination host %s - Proxy host: %s:%s", self.host, self.proxy_host, self.proxy_port,) From 9efd70956c3e1d4f891508744317dccbc1772c5c Mon Sep 17 00:00:00 2001 From: Dan <22e889d8@opayq.com> Date: Wed, 16 Nov 2016 00:00:43 +0000 Subject: [PATCH 2/8] Remove breakpoints --- examples/parallel_commands.py | 3 +-- pssh/ssh_client.py | 1 - 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/examples/parallel_commands.py b/examples/parallel_commands.py index 509c1cd7..353afef2 100644 --- a/examples/parallel_commands.py +++ b/examples/parallel_commands.py @@ -2,10 +2,9 @@ import datetime output = [] -host = '192.168.1.2' +host = 'localhost' hosts = [host] client = ParallelSSHClient(hosts) -import ipdb; ipdb.set_trace() # Run 10 five second sleeps cmds = ['sleep 5' for _ in xrange(10)] diff --git a/pssh/ssh_client.py b/pssh/ssh_client.py index 565087cb..fa653638 100644 --- a/pssh/ssh_client.py +++ b/pssh/ssh_client.py @@ -120,7 +120,6 @@ def __init__(self, host, self.proxy_pkey = proxy_host, proxy_port, proxy_user, \ proxy_password, proxy_pkey self.proxy_client = None - import ipdb; ipdb.set_trace() if self.proxy_host and self.proxy_port: logger.debug("Proxy configured for destination host %s - Proxy host: %s:%s", self.host, self.proxy_host, self.proxy_port,) From 37d8bd181955f149ea575374d43ebfcd63b1a17d Mon Sep 17 00:00:00 2001 From: Dan <22e889d8@opayq.com> Date: Wed, 4 Jan 2017 18:07:31 +0000 Subject: [PATCH 3/8] Updated requirements --- requirements.txt | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/requirements.txt b/requirements.txt index afc628bb..b5c38ea5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,3 @@ -setuptools>=21.0 paramiko>=1.12,!=1.16.0 -gevent<=1.1; python_version < '3' -gevent>=1.1; python_version >= '3' +gevent<=1.1; python_version < '2.7' +gevent>=1.1; python_version >= '2.7' From 7c13c0f34b98bb660632d7346314ba91a14b9be5 Mon Sep 17 00:00:00 2001 From: Dan <22e889d8@opayq.com> Date: Fri, 6 Jan 2017 18:17:01 +0000 Subject: [PATCH 4/8] Refactored embedded to be cleaner. Updated tests to use separate processes for embedded server --- embedded_server/embedded_server.py | 197 ++++++++++--------- tests/test_pssh_client.py | 303 ++++++++++++++--------------- 2 files changed, 247 insertions(+), 253 deletions(-) diff --git a/embedded_server/embedded_server.py b/embedded_server/embedded_server.py index cc03d098..0b2374cc 100644 --- a/embedded_server/embedded_server.py +++ b/embedded_server/embedded_server.py @@ -41,9 +41,11 @@ import sys if 'threading' in sys.modules: del sys.modules['threading'] -import gevent +from gevent import monkey +monkey.patch_all() +from multiprocessing import Process import os -import socket +import gevent from gevent import socket from gevent.event import Event import sys @@ -65,10 +67,10 @@ class Server(paramiko.ServerInterface): """Implements :mod:`paramiko.ServerInterface` to provide an embedded SSH server implementation. - Start a `Server` with at least a transport and a host key. + Start a `Server` with at least a host private key. Any SSH2 client with public key or password authentication - is allowed, only. Shell requests are not accepted. + is allowed, only. Interactive shell requests are not accepted. Implemented: * Direct tcp-ip channels (tunneling) @@ -77,19 +79,88 @@ class Server(paramiko.ServerInterface): * Exec requests (run a command on server) Not Implemented: - * Shell requests + * Interactive shell requests """ - def __init__(self, transport, host_key, fail_auth=False, - ssh_exception=False): + def __init__(self, host_key, fail_auth=False, + ssh_exception=False, + socket=None, + port=0, + listen_ip='127.0.0.1', + timeout=None): + if not socket: + self.socket = make_socket(listen_ip, port) + if not self.socket: + msg = "Could not establish listening connection on %s:%s" + logger.error(msg, listen_ip, port) + raise Exception(msg, listen_ip, port) + self.listen_ip = listen_ip + self.listen_port = self.socket.getsockname()[1] self.event = Event() self.fail_auth = fail_auth self.ssh_exception = ssh_exception - self.transport = transport self.host_key = host_key - transport.load_server_moduli() - transport.add_server_key(self.host_key) - transport.set_subsystem_handler('sftp', paramiko.SFTPServer, StubSFTPServer) + self.transport = None + self.timeout = timeout + + def start_listening(self): + try: + self.socket.listen(100) + logger.info('Listening for connection on %s:%s..', self.listen_ip, + self.listen_port) + except Exception as e: + logger.error('*** Listen failed: %s' % (str(e),)) + traceback.print_exc() + raise + conn, addr = self.socket.accept() + logger.info('Got connection..') + if self.timeout: + logger.debug("SSH server sleeping for %s then raising socket.timeout", + self.timeout) + gevent.Timeout(self.timeout).start() + self.transport = paramiko.Transport(conn) + self.transport.load_server_moduli() + self.transport.add_server_key(self.host_key) + self.transport.set_subsystem_handler('sftp', paramiko.SFTPServer, + StubSFTPServer) + try: + self.transport.start_server(server=self) + except paramiko.SSHException as e: + logger.exception('SSH negotiation failed') + raise + + def run(self): + while True: + try: + self.start_listening() + except Exception: + logger.exception("Error occured starting server") + continue + try: + self.accept_connections() + except Exception as e: + logger.error('*** Caught exception: %s: %s' % (str(e.__class__), str(e),)) + traceback.print_exc() + try: + self.transport.close() + except Exception: + pass + raise + + def accept_connections(self): + while True: + gevent.sleep(0) + channel = self.transport.accept(20) + if not channel: + logger.error("Could not establish channel") + return + while self.transport.is_active(): + logger.debug("Transport active, waiting..") + gevent.sleep(1) + while not channel.send_ready(): + gevent.sleep(.2) + channel.close() + gevent.sleep(0) def check_channel_request(self, kind, chanid): return paramiko.OPEN_SUCCEEDED @@ -157,7 +228,7 @@ def _read_response(self, channel, process): channel.send_exit_status(process.returncode) logger.debug("Command finished with return code %s", process.returncode) # Let clients consume output from channel before closing - gevent.sleep(.2) + gevent.sleep(.1) channel.close() def make_socket(listen_ip, port=0): @@ -172,92 +243,24 @@ def make_socket(listen_ip, port=0): return return sock -def listen(sock, fail_auth=False, ssh_exception=False, - timeout=None): - """Run server and given a cmd_to_run, send given - response to client connection. Returns (server, socket) tuple - where server is a joinable server thread and socket is listening - socket of server. - """ - listen_ip, listen_port = sock.getsockname() - if not sock: - logger.error("Could not establish listening connection on %s:%s", - listen_ip, listen_port) - return +def start_server(listen_ip, fail_auth=False, ssh_exception=False, + timeout=None, + listen_port=0): + server = Server(host_key, listen_ip=listen_ip, port=listen_port, + fail_auth=fail_auth, ssh_exception=ssh_exception, + timeout=timeout) try: - sock.listen(100) - logger.info('Listening for connection on %s:%s..', listen_ip, - listen_port) - except Exception as e: - logger.error('*** Listen failed: %s' % (str(e),)) - traceback.print_exc() - return - handle_ssh_connection(sock, fail_auth=fail_auth, - timeout=timeout, ssh_exception=ssh_exception) - -def _handle_ssh_connection(transport, fail_auth=False, - ssh_exception=False): - server = Server(transport, HOST_KEY, - fail_auth=fail_auth, ssh_exception=ssh_exception) - # server.run() - try: - transport.start_server(server=server) - except paramiko.SSHException as e: - logger.exception('SSH negotiation failed') - except Exception: - logger.exception("Error occured starting server") - return - while True: - gevent.sleep(0) - channel = transport.accept(20) - if not channel: - logger.error("Could not establish channel") - return - while transport.is_active(): - logger.debug("Transport active, waiting..") - gevent.sleep(1) - while not channel.send_ready(): - gevent.sleep(.2) - channel.close() - gevent.sleep(0) - -def handle_ssh_connection(sock, - fail_auth=False, ssh_exception=False, - timeout=None): - conn, addr = sock.accept() - logger.info('Got connection..') - if timeout: - logger.debug("SSH server sleeping for %s then raising socket.timeout", - timeout) - gevent.Timeout(timeout).start() - try: - transport = paramiko.Transport(conn) - _handle_ssh_connection(transport, fail_auth=fail_auth, - ssh_exception=ssh_exception) - except Exception as e: - logger.error('*** Caught exception: %s: %s' % (str(e.__class__), str(e),)) - traceback.print_exc() - try: - transport.close() - except: - pass - return - -def start_server(sock, fail_auth=False, ssh_exception=False, - timeout=None): - g = gevent.spawn(listen, sock, fail_auth=fail_auth, - timeout=timeout, ssh_exception=ssh_exception) - try: - g.join() + server.run() except KeyboardInterrupt: sys.exit(0) -if __name__ == "__main__": - logging.basicConfig() - logger.setLevel(logging.DEBUG) - sock = make_socket('127.0.0.1') - server = start_server(sock) - try: - server.join() - except KeyboardInterrupt: - sys.exit(0) +def start_server_process(listen_ip, fail_auth=False, ssh_exception=False, + timeout=None, listen_port=0): + server = Process(target=start_server, args=(listen_ip,), + kwargs={ + 'listen_port': listen_port, + 'fail_auth': fail_auth, + 'ssh_exception': ssh_exception, + 'timeout': timeout, + }) + return server diff --git a/tests/test_pssh_client.py b/tests/test_pssh_client.py index c2fadfc4..d2b3dbb4 100644 --- a/tests/test_pssh_client.py +++ b/tests/test_pssh_client.py @@ -27,6 +27,7 @@ import warnings import shutil import sys +from multiprocessing import Process import gevent from pssh import ParallelSSHClient, UnknownHostException, \ @@ -35,7 +36,7 @@ from pssh.exceptions import HostArgumentException from pssh.utils import load_private_key from embedded_server.embedded_server import start_server, make_socket, \ - logger as server_logger, paramiko_logger + logger as server_logger, paramiko_logger, start_server_process from embedded_server.fake_agent import FakeAgent from paramiko import RSAKey @@ -54,20 +55,26 @@ def setUp(self): self.long_cmd = lambda lines: 'for (( i=0; i<%s; i+=1 )) do echo $i; sleep 1; done' % (lines,) self.user_key = USER_KEY self.host = '127.0.0.1' - self.listen_socket = make_socket(self.host) - self.listen_port = self.listen_socket.getsockname()[1] - self.server = start_server(self.listen_socket) + self.listen_port = self.make_random_port() + self.server = start_server_process(self.host, + listen_port=self.listen_port) self.agent = FakeAgent() self.agent.add_key(USER_KEY) self.client = ParallelSSHClient([self.host], port=self.listen_port, pkey=self.user_key, agent=self.agent) - + + def make_random_port(self, host=None): + host = self.host if not host else host + listen_socket = make_socket(host) + listen_port = listen_socket.getsockname()[1] + del listen_socket + return listen_port + def tearDown(self): - del self.server - del self.listen_socket del self.client - + self.server.terminate() + def test_pssh_client_no_stdout_non_zero_exit_code_immediate_exit(self): output = self.client.run_command('exit 1') expected_exit_code = 1 @@ -79,10 +86,7 @@ def test_pssh_client_no_stdout_non_zero_exit_code_immediate_exit(self): expected_exit_code,)) def test_pssh_client_run_command_get_output(self): - client = ParallelSSHClient([self.host], port=self.listen_port, - pkey=self.user_key, - agent=self.agent) - output = client.run_command(self.fake_cmd) + output = self.client.run_command(self.fake_cmd) expected_exit_code = 0 expected_stdout = [self.fake_resp] expected_stderr = [] @@ -103,13 +107,11 @@ def test_pssh_client_run_command_get_output(self): expected_stderr,)) def test_pssh_client_run_command_get_output_explicit(self): - client = ParallelSSHClient([self.host], port=self.listen_port, - pkey=self.user_key) - out = client.run_command(self.fake_cmd) + out = self.client.run_command(self.fake_cmd) cmds = [cmd for host in out for cmd in [out[host]['cmd']]] output = {} for cmd in cmds: - client.get_output(cmd, output) + self.client.get_output(cmd, output) expected_exit_code = 0 expected_stdout = [self.fake_resp] expected_stderr = [] @@ -128,37 +130,32 @@ def test_pssh_client_run_command_get_output_explicit(self): msg="Got unexpected stderr - %s, expected %s" % (stderr, expected_stderr,)) - del client def test_pssh_client_run_long_command(self): expected_lines = 5 - client = ParallelSSHClient([self.host], port=self.listen_port, - pkey=self.user_key) - output = client.run_command(self.long_cmd(expected_lines)) + output = self.client.run_command(self.long_cmd(expected_lines)) self.assertTrue(self.host in output, msg="Got no output for command") stdout = list(output[self.host]['stdout']) self.assertTrue(len(stdout) == expected_lines, msg="Expected %s lines of response, got %s" % ( expected_lines, len(stdout))) - del client def test_pssh_client_auth_failure(self): - self.server.kill() - server_socket = make_socket(self.host) - listen_port = server_socket.getsockname()[1] - server = start_server(server_socket, fail_auth=True) + listen_port = self.make_random_port() + server = start_server_process(self.host, listen_port=listen_port, + fail_auth=True) client = ParallelSSHClient([self.host], port=listen_port, pkey=self.user_key, agent=self.agent) self.assertRaises(AuthenticationException, client.run_command, self.fake_cmd) del client + server.terminate() def test_pssh_client_hosts_list_part_failure(self): """Test getting output for remainder of host list in the case where one host in the host list has a failure""" - server2_socket = make_socket('127.0.0.2', port=self.listen_port) - server2_port = server2_socket.getsockname()[1] - server2 = start_server(server2_socket, fail_auth=True) + server2 = start_server_process('127.0.0.2', listen_port=self.listen_port, + fail_auth=True) hosts = [self.host, '127.0.0.2'] client = ParallelSSHClient(hosts, port=self.listen_port, @@ -183,47 +180,45 @@ def test_pssh_client_hosts_list_part_failure(self): raise Exception("Expected AuthenticationException, got %s instead" % ( output[hosts[1]]['exception'],)) del client - server2.kill() - + server2.terminate() + def test_pssh_client_ssh_exception(self): - listen_socket = make_socket(self.host) - listen_port = listen_socket.getsockname()[1] - server = start_server(listen_socket, - ssh_exception=True) + listen_port = self.make_random_port() + server = start_server_process(self.host, + listen_port=listen_port, + ssh_exception=True) client = ParallelSSHClient([self.host], user='fakey', password='fakey', port=listen_port, pkey=RSAKey.generate(1024), + num_retries=1, ) self.assertRaises(SSHException, client.run_command, self.fake_cmd) del client - server.kill() + server.terminate() def test_pssh_client_timeout(self): - self.server.kill() - listen_socket = make_socket(self.host) - listen_port = listen_socket.getsockname()[1] + listen_port = self.make_random_port() server_timeout=0.2 client_timeout=server_timeout-0.1 - server = start_server(listen_socket, - timeout=server_timeout) + server = start_server_process(self.host, + listen_port=listen_port, + timeout=server_timeout) client = ParallelSSHClient([self.host], port=listen_port, pkey=self.user_key, - timeout=client_timeout) - output = client.run_command(self.fake_cmd) + timeout=client_timeout, + num_retries=1) + output = client.run_command(self.fake_cmd, stop_on_errors=False) # Handle exception try: gevent.sleep(server_timeout+0.2) - client.pool.join() - if not server.exception: + client.join(output) + if not server.exitcode == 1: raise Exception( "Expected gevent.Timeout from socket timeout, got none") - raise server.exception - except gevent.Timeout: - pass finally: del client - server.kill() + server.terminate() def test_pssh_client_run_command_password(self): """Test password authentication. Embedded server accepts any password @@ -231,38 +226,34 @@ def test_pssh_client_run_command_password(self): client = ParallelSSHClient([self.host], port=self.listen_port, password='') output = client.run_command(self.fake_cmd) - client.join(output) + stdout = list(output[self.host]['stdout']) self.assertTrue(self.host in output, msg="No output for host") self.assertTrue(output[self.host]['exit_code'] == 0, msg="Expected exit code 0, got %s" % ( output[self.host]['exit_code'],)) - del client + self.assertEqual(stdout, [self.fake_resp]) def test_pssh_client_long_running_command_exit_codes(self): expected_lines = 5 - client = ParallelSSHClient([self.host], port=self.listen_port, - pkey=self.user_key) - output = client.run_command(self.long_cmd(expected_lines)) + output = self.client.run_command(self.long_cmd(expected_lines)) self.assertTrue(self.host in output, msg="Got no output for command") self.assertTrue(not output[self.host]['exit_code'], msg="Got exit code %s for still running cmd.." % ( output[self.host]['exit_code'],)) - self.assertFalse(client.finished(output)) + self.assertFalse(self.client.finished(output)) # Embedded server is also asynchronous and in the same thread # as our client so need to sleep for duration of server connection gevent.sleep(expected_lines) - client.join(output) - self.assertTrue(client.finished(output)) + self.client.join(output) + self.assertTrue(self.client.finished(output)) self.assertTrue(output[self.host]['exit_code'] == 0, msg="Got non-zero exit code %s" % ( output[self.host]['exit_code'],)) - del client def test_pssh_client_retries(self): """Test connection error retries""" - listen_socket = make_socket(self.host) - listen_port = listen_socket.getsockname()[1] + listen_port = self.make_random_port() expected_num_tries = 2 client = ParallelSSHClient([self.host], port=listen_port, pkey=self.user_key, @@ -273,17 +264,15 @@ def test_pssh_client_retries(self): except ConnectionErrorException, ex: num_tries = ex.args[-1:][0] self.assertEqual(expected_num_tries, num_tries, - msg="Got unexpected number of retries %s - expected %s" - % (num_tries, expected_num_tries,)) + msg="Got unexpected number of retries %s - " + "expected %s" % (num_tries, expected_num_tries,)) else: raise Exception('No ConnectionErrorException') def test_sftp_exceptions(self): - self.server.kill() - # Make socket with no server listening on it on separate ip + # Port with no server listening on it on separate ip host = '127.0.0.3' - _socket = make_socket(host) - port = _socket.getsockname()[1] + port = self.make_random_port(host=host) client = ParallelSSHClient([self.host], port=port, num_retries=1) cmds = client.copy_file("test", "test") client.pool.join() @@ -484,10 +473,9 @@ def test_pssh_hosts_more_than_pool_size(self): """Test we can successfully run on more hosts than our pool size and get logs for all hosts""" # Make a second server on the same port as the first one - server2_socket = make_socket('127.0.0.2', port=self.listen_port) - server2_port = server2_socket.getsockname()[1] - server2 = start_server(server2_socket) - hosts = [self.host, '127.0.0.2'] + host2 = '127.0.0.2' + server2 = start_server_process(host2, listen_port=self.listen_port) + hosts = [self.host, host2] client = ParallelSSHClient(hosts, port=self.listen_port, pkey=self.user_key, @@ -503,16 +491,13 @@ def test_pssh_hosts_more_than_pool_size(self): msg="Did not get expected output from all hosts. \ Got %s - expected %s" % (stdout, expected_stdout,)) del client - del server2 - + server2.terminate() + def test_pssh_hosts_iterator_hosts_modification(self): """Test using iterator as host list and modifying host list in place""" - server2_socket = make_socket('127.0.0.2', port=self.listen_port) - server2_port = server2_socket.getsockname()[1] - server2 = start_server(server2_socket) - server3_socket = make_socket('127.0.0.3', port=self.listen_port) - server3_port = server3_socket.getsockname()[1] - server3 = start_server(server3_socket) + host2, host3 = '127.0.0.2', '127.0.0.3' + server2 = start_server_process(host2, listen_port=self.listen_port) + server3 = start_server_process(host3, listen_port=self.listen_port) hosts = [self.host, '127.0.0.2'] client = ParallelSSHClient(iter(hosts), port=self.listen_port, @@ -540,64 +525,71 @@ def test_pssh_hosts_iterator_hosts_modification(self): "%s/%s hosts" % (len(output), len(hosts),)) self.assertTrue(hosts[1] in output, msg="Did not get output for new host %s" % (hosts[1],)) - del client, server2, server3 - + del client + server2.terminate() + server3.terminate() + def test_ssh_proxy(self): """Test connecting to remote destination via SSH proxy client -> proxy -> destination Proxy SSH server accepts no commands and sends no responses, only proxies to destination. Destination accepts a command as usual.""" - proxy_server_socket = make_socket('127.0.0.2') - proxy_server_port = proxy_server_socket.getsockname()[1] - proxy_server = start_server(proxy_server_socket) - gevent.sleep(2) + proxy_host = '127.0.0.2' + proxy_server_port = self.make_random_port(proxy_host) + proxy_server = start_server_process(proxy_host, + listen_port=proxy_server_port) client = ParallelSSHClient([self.host], port=self.listen_port, pkey=self.user_key, - proxy_host='127.0.0.2', + proxy_host=proxy_host, proxy_port=proxy_server_port ) gevent.sleep(2) output = client.run_command(self.fake_cmd) + gevent.sleep(.2) stdout = list(output[self.host]['stdout']) expected_stdout = [self.fake_resp] self.assertEqual(expected_stdout, stdout, msg="Got unexpected stdout - %s, expected %s" % (stdout, expected_stdout,)) - self.server.kill() - proxy_server.kill() + del client + proxy_server.terminate() def test_ssh_proxy_auth(self): """Test connecting to remote destination via SSH proxy client -> proxy -> destination Proxy SSH server accepts no commands and sends no responses, only proxies to destination. Destination accepts a command as usual.""" - proxy_server_socket = make_socket('127.0.0.2') - proxy_server_port = proxy_server_socket.getsockname()[1] - proxy_server = start_server(proxy_server_socket) + host2 = '127.0.0.2' + proxy_server_port = self.make_random_port(host=host2) + proxy_server = start_server_process(host2, listen_port=proxy_server_port) proxy_user = 'proxy_user' proxy_password = 'fake' gevent.sleep(2) client = ParallelSSHClient([self.host], port=self.listen_port, pkey=self.user_key, - proxy_host='127.0.0.2', + proxy_host=host2, proxy_port=proxy_server_port, proxy_user=proxy_user, proxy_password='fake', proxy_pkey=self.user_key, + num_retries=1, ) gevent.sleep(2) - output = client.run_command(self.fake_cmd) - stdout = list(output[self.host]['stdout']) expected_stdout = [self.fake_resp] - self.assertEqual(expected_stdout, stdout, - msg="Got unexpected stdout - %s, expected %s" % ( - stdout, expected_stdout,)) - self.assertEqual(client.host_clients[self.host].proxy_user, proxy_user) - self.assertEqual(client.host_clients[self.host].proxy_password, proxy_password) - self.assertTrue(client.host_clients[self.host].proxy_pkey) - self.server.kill() - proxy_server.kill() + try: + output = client.run_command(self.fake_cmd) + stdout = list(output[self.host]['stdout']) + self.assertEqual(expected_stdout, stdout, + msg="Got unexpected stdout - %s, expected %s" % ( + stdout, expected_stdout,)) + self.assertEqual(client.host_clients[self.host].proxy_user, + proxy_user) + self.assertEqual(client.host_clients[self.host].proxy_password, + proxy_password) + self.assertTrue(client.host_clients[self.host].proxy_pkey) + finally: + proxy_server.terminate() def test_ssh_proxy_auth_fail(self): """Test failures while connecting via proxy""" @@ -636,32 +628,30 @@ def test_bash_variable_substitution(self): self.assertEqual(output, expected, msg="Unexpected output from bash variable substitution %s - should be %s" % ( output, expected,)) - + def test_identical_host_output(self): """Test that we get output when running with duplicated hosts""" - # Make socket with no server listening on it just for testing output - _socket1, _socket2 = make_socket(self.host), make_socket(self.host) - port = _socket1.getsockname()[1] + # Make port with no server listening on it just for testing output + port = self.make_random_port() hosts = [self.host, self.host, self.host] client = ParallelSSHClient(hosts, port=port, - pkey=self.user_key) + pkey=self.user_key, + num_retries=1) output = client.run_command(self.fake_cmd, stop_on_errors=False) client.pool.join() self.assertEqual(len(hosts), len(output.keys()), msg="Host list contains %s identical hosts, only got output for %s" % ( len(hosts), len(output.keys()))) - del _socket1, _socket2 - + def test_connection_error_exception(self): """Test that we get connection error exception in output with correct arguments""" - self.server.kill() - # Make socket with no server listening on it on separate ip + # Make port with no server listening on it on separate ip host = '127.0.0.3' - _socket = make_socket(host) - port = _socket.getsockname()[1] + port = self.make_random_port(host=host) hosts = [host] client = ParallelSSHClient(hosts, port=port, - pkey=self.user_key) + pkey=self.user_key, + num_retries=1) output = client.run_command(self.fake_cmd, stop_on_errors=False) client.pool.join() self.assertTrue('exception' in output[host], @@ -678,18 +668,17 @@ def test_connection_error_exception(self): ex.args[2], port,)) else: raise Exception("Expected ConnectionErrorException") - del _socket - + def test_authentication_exception(self): """Test that we get authentication exception in output with correct arguments""" - self.server.kill() - _socket = make_socket(self.host) - port = _socket.getsockname()[1] - server = start_server(_socket, fail_auth=True) + port = self.make_random_port() + server = start_server_process(self.host, fail_auth=True, + listen_port=port) hosts = [self.host] client = ParallelSSHClient(hosts, port=port, pkey=self.user_key, - agent=self.agent) + agent=self.agent, + num_retries=1) output = client.run_command(self.fake_cmd, stop_on_errors=False) client.pool.join() self.assertTrue('exception' in output[self.host], @@ -706,22 +695,20 @@ def test_authentication_exception(self): ex.args[2], port,)) else: raise Exception("Expected AuthenticationException") - server.kill() - del _socket - + server.terminate() + def test_ssh_exception(self): """Test that we get ssh exception in output with correct arguments""" - self.server.kill() host = '127.0.0.10' - _socket = make_socket(host) - port = _socket.getsockname()[1] - server = start_server(_socket, ssh_exception=True) + port = self.make_random_port(host=host) + server = start_server_process(host, ssh_exception=True, + listen_port=port) hosts = [host] client = ParallelSSHClient(hosts, port=port, user='fakey', password='fakey', - pkey=RSAKey.generate(1024)) + pkey=RSAKey.generate(1024), + num_retries=1) output = client.run_command(self.fake_cmd, stop_on_errors=False) - gevent.sleep(1) client.pool.join() self.assertTrue('exception' in output[host], msg="Got no exception for host %s - expected connection error" % ( @@ -737,8 +724,7 @@ def test_ssh_exception(self): ex.args[2], port,)) else: raise Exception("Expected SSHException") - server.kill() - del _socket + server.terminate() def test_multiple_single_quotes_in_cmd(self): """Test that we can run a command with multiple single quotes""" @@ -794,14 +780,14 @@ def test_host_config(self): user = 'overriden_user' password = 'overriden_pass' for host in hosts: - _socket = make_socket(host) - port = _socket.getsockname()[1] + port = self.make_random_port(host=host) host_config[host] = {} host_config[host]['port'] = port host_config[host]['user'] = user host_config[host]['password'] = password - server = start_server(_socket, fail_auth=hosts.index(host)) - servers.append((server, port)) + server = start_server_process(host, fail_auth=hosts.index(host), + listen_port=port) + servers.append(server) pkey_data = load_private_key(PKEY_FILENAME) host_config[hosts[0]]['private_key'] = pkey_data client = ParallelSSHClient(hosts, host_config=host_config) @@ -824,14 +810,12 @@ def test_host_config(self): msg="Host config password override failed") self.assertTrue(client.host_clients[hosts[0]].pkey == pkey_data, msg="Host config pkey override failed") - for (server, _) in servers: - server.kill() + for server in servers: + server.terminate() def test_pssh_client_override_allow_agent_authentication(self): """Test running command with allow_agent set to False""" - client = ParallelSSHClient([self.host], port=self.listen_port, - pkey=self.user_key, allow_agent=False) - output = client.run_command(self.fake_cmd) + output = self.client.run_command(self.fake_cmd) expected_exit_code = 0 expected_stdout = [self.fake_resp] expected_stderr = [] @@ -850,24 +834,21 @@ def test_pssh_client_override_allow_agent_authentication(self): msg="Got unexpected stderr - %s, expected %s" % (stderr, expected_stderr,)) - del client def test_get_exit_codes_bad_output(self): self.assertFalse(self.client.get_exit_codes({})) self.assertFalse(self.client.get_exit_code({})) def test_per_host_tuple_args(self): - server2_socket = make_socket('127.0.0.2', port=self.listen_port) - server2_port = server2_socket.getsockname()[1] - server2 = start_server(server2_socket) - server3_socket = make_socket('127.0.0.3', port=self.listen_port) - server3_port = server3_socket.getsockname()[1] - server3 = start_server(server3_socket) - hosts = [self.host, '127.0.0.2', '127.0.0.3'] + host2, host3 = '127.0.0.2', '127.0.0.3' + server2 = start_server_process(host2, listen_port=self.listen_port) + server3 = start_server_process(host3, listen_port=self.listen_port) + hosts = [self.host, host2, host3] host_args = ('arg1', 'arg2', 'arg3') cmd = 'echo %s' client = ParallelSSHClient(hosts, port=self.listen_port, - pkey=self.user_key) + pkey=self.user_key, + num_retries=1) output = client.run_command(cmd, host_args=host_args) for i, host in enumerate(hosts): expected = [host_args[i]] @@ -884,22 +865,25 @@ def test_per_host_tuple_args(self): self.assertTrue(output[host]['exit_code'] == 0) self.assertRaises(HostArgumentException, client.run_command, cmd, host_args=[host_args[0]]) + # Invalid number of args + host_args = (('arg1', ),) + self.assertRaises(TypeError, client.run_command, cmd, host_args=host_args) + for server in [server2, server3]: + server.terminate() def test_per_host_dict_args(self): - server2_socket = make_socket('127.0.0.2', port=self.listen_port) - server2_port = server2_socket.getsockname()[1] - server2 = start_server(server2_socket) - server3_socket = make_socket('127.0.0.3', port=self.listen_port) - server3_port = server3_socket.getsockname()[1] - server3 = start_server(server3_socket) - hosts = [self.host, '127.0.0.2', '127.0.0.3'] + host2, host3 = '127.0.0.2', '127.0.0.3' + server2 = start_server_process(host2, listen_port=self.listen_port) + server3 = start_server_process(host3, listen_port=self.listen_port) + hosts = [self.host, host2, host3] hosts_gen = (h for h in hosts) host_args = [dict(zip(('host_arg1', 'host_arg2',), ('arg1-%s' % (i,), 'arg2-%s' % (i,),))) for i, _ in enumerate(hosts)] cmd = 'echo %(host_arg1)s %(host_arg2)s' client = ParallelSSHClient(hosts, port=self.listen_port, - pkey=self.user_key) + pkey=self.user_key, + num_retries=1) output = client.run_command(cmd, host_args=host_args) for i, host in enumerate(hosts): expected = ["%(host_arg1)s %(host_arg2)s" % host_args[i]] @@ -919,6 +903,13 @@ def test_per_host_dict_args(self): client.hosts = (h for h in hosts) self.assertRaises(HostArgumentException, client.run_command, cmd, host_args=[host_args[0]]) + client.hosts = hosts + + def test_per_host_dict_args_invalid(self): + cmd = 'echo %(host_arg1)s %(host_arg2)s' + # Invalid number of host args + host_args = [{'host_arg1': 'arg1'}] + self.assertRaises(KeyError, self.client.run_command, cmd, host_args=host_args) if __name__ == '__main__': unittest.main() From 483e085e9b8d90205239e32097ba7f171d74f890 Mon Sep 17 00:00:00 2001 From: Dan <22e889d8@opayq.com> Date: Fri, 6 Jan 2017 18:18:42 +0000 Subject: [PATCH 5/8] WIP - try and fix event loop block on embedded server proxy connections. Added invalid host args when using dict --- embedded_server/embedded_server.py | 64 +++++++++++++++++++++++------- embedded_server/tunnel.py | 8 +++- pssh/pssh_client.py | 18 ++++++--- tests/test_pssh_client.py | 63 +++++++++++++++-------------- 4 files changed, 103 insertions(+), 50 deletions(-) diff --git a/embedded_server/embedded_server.py b/embedded_server/embedded_server.py index 0b2374cc..2fe5da96 100644 --- a/embedded_server/embedded_server.py +++ b/embedded_server/embedded_server.py @@ -38,12 +38,13 @@ *Warning* - Note that commands, with or without a shell, are actually run on the system running this server. Destructive commands will affect the system as permissions of user running the server allow. **Use at your own risk**. """ +from gipc import start_process +from multiprocessing import Process import sys if 'threading' in sys.modules: del sys.modules['threading'] from gevent import monkey monkey.patch_all() -from multiprocessing import Process import os import gevent from gevent import socket @@ -53,9 +54,11 @@ import logging import paramiko import time +import gevent.subprocess +import gevent.hub + from .stub_sftp import StubSFTPServer from .tunnel import Tunneler -import gevent.subprocess logger = logging.getLogger("embedded_server") paramiko_logger = logging.getLogger('paramiko.transport') @@ -86,16 +89,16 @@ def __init__(self, host_key, fail_auth=False, ssh_exception=False, socket=None, port=0, - listen_ip='127.0.0.1', + host='127.0.0.1', timeout=None): if not socket: - self.socket = make_socket(listen_ip, port) + self.socket = make_socket(host, port) if not self.socket: msg = "Could not establish listening connection on %s:%s" - logger.error(msg, listen_ip, port) - raise Exception(msg, listen_ip, port) - self.listen_ip = listen_ip - self.listen_port = self.socket.getsockname()[1] + logger.error(msg, host, port) + raise Exception(msg, host, port) + self.host = host + self.port = self.socket.getsockname()[1] self.event = Event() self.fail_auth = fail_auth self.ssh_exception = ssh_exception @@ -105,15 +108,20 @@ def __init__(self, host_key, fail_auth=False, def start_listening(self): try: + gevent.sleep(0) self.socket.listen(100) - logger.info('Listening for connection on %s:%s..', self.listen_ip, - self.listen_port) + logger.info('Listening for connection on %s:%s..', self.host, + self.port) except Exception as e: logger.error('*** Listen failed: %s' % (str(e),)) traceback.print_exc() raise + gevent.sleep() conn, addr = self.socket.accept() + gevent.sleep(.2) logger.info('Got connection..') + # import ipdb; ipdb.set_trace() + gevent.sleep(.2) if self.timeout: logger.debug("SSH server sleeping for %s then raising socket.timeout", self.timeout) @@ -123,21 +131,26 @@ def start_listening(self): self.transport.add_server_key(self.host_key) self.transport.set_subsystem_handler('sftp', paramiko.SFTPServer, StubSFTPServer) + gevent.sleep() try: self.transport.start_server(server=self) except paramiko.SSHException as e: logger.exception('SSH negotiation failed') raise + gevent.sleep(0) def run(self): while True: try: self.start_listening() + gevent.sleep(0) except Exception: logger.exception("Error occured starting server") continue + gevent.sleep(0) try: self.accept_connections() + gevent.sleep(0) except Exception as e: logger.error('*** Caught exception: %s: %s' % (str(e.__class__), str(e),)) traceback.print_exc() @@ -152,8 +165,10 @@ def accept_connections(self): gevent.sleep(0) channel = self.transport.accept(20) if not channel: - logger.error("Could not establish channel") - return + logger.error("Could not establish channel on %s:%s", + self.host, self.port) + gevent.sleep(0) + continue while self.transport.is_active(): logger.debug("Transport active, waiting..") gevent.sleep(1) @@ -190,21 +205,28 @@ def check_channel_pty_request(self, channel, term, width, height, pixelwidth, return True def check_channel_direct_tcpip_request(self, chanid, origin, destination): + # import ipdb; ipdb.set_trace() logger.debug("Proxy connection %s -> %s requested", origin, destination,) extra = {'username' : self.transport.get_username()} logger.debug("Starting proxy connection %s -> %s", origin, destination, extra=extra) + self.event.set() try: - tunnel = Tunneler(destination, self.transport, chanid) + gevent.sleep(.2) + tunnel = Process(target=Tunneler, args=(destination, self.transport, chanid,)) + tunnel.daemon = True tunnel.start() + gevent.sleep(.2) except Exception as ex: logger.error("Error creating proxy connection to %s - %s", destination, ex,) return paramiko.OPEN_FAILED_CONNECT_FAILED + gevent.sleep(2) return paramiko.OPEN_SUCCEEDED def check_channel_forward_agent_request(self, channel): logger.debug("Forward agent key request for channel %s" % (channel,)) + gevent.sleep(0) return True def check_channel_exec_request(self, channel, cmd, @@ -217,6 +239,7 @@ def check_channel_exec_request(self, channel, cmd, stdin=gevent.subprocess.PIPE, shell=True, env=_env) gevent.spawn(self._read_response, channel, process) + gevent.sleep(0) return True def _read_response(self, channel, process): @@ -230,6 +253,7 @@ def _read_response(self, channel, process): # Let clients consume output from channel before closing gevent.sleep(.1) channel.close() + gevent.sleep(0) def make_socket(listen_ip, port=0): """Make socket on given address and available port chosen by OS""" @@ -246,16 +270,26 @@ def make_socket(listen_ip, port=0): def start_server(listen_ip, fail_auth=False, ssh_exception=False, timeout=None, listen_port=0): - server = Server(host_key, listen_ip=listen_ip, port=listen_port, + # gevent.reinit() + gevent.hub.reinit() + # h.destroy(destroy_loop=True) + # h = gevent.hub.Hub() + # h.NOT_ERROR = (Exception,) + # gevent.hub.set_hub(h) + server = Server(host_key, host=listen_ip, port=listen_port, fail_auth=fail_auth, ssh_exception=ssh_exception, timeout=timeout) try: server.run() except KeyboardInterrupt: sys.exit(0) + # listen_process = Process(target=server.run) + # listen_process.start() + # listen_process.join() def start_server_process(listen_ip, fail_auth=False, ssh_exception=False, timeout=None, listen_port=0): + gevent.reinit() server = Process(target=start_server, args=(listen_ip,), kwargs={ 'listen_port': listen_port, @@ -263,4 +297,6 @@ def start_server_process(listen_ip, fail_auth=False, ssh_exception=False, 'ssh_exception': ssh_exception, 'timeout': timeout, }) + server.start() + gevent.sleep(.2) return server diff --git a/embedded_server/tunnel.py b/embedded_server/tunnel.py index 88ebfb80..5ec6efd2 100644 --- a/embedded_server/tunnel.py +++ b/embedded_server/tunnel.py @@ -29,9 +29,11 @@ class Tunneler(gevent.Greenlet): def __init__(self, address, transport, chanid): gevent.Greenlet.__init__(self) + gevent.sleep(.2) self.socket = socket.create_connection(address) self.transport = transport self.chanid = chanid + gevent.sleep(0) def close(self): try: @@ -50,13 +52,16 @@ def tunnel(self, dest_socket, source_chan): response_data = dest_socket.recv(1024) source_chan.sendall(response_data) logger.debug("Tunnel sent data..") - gevent.sleep(0) + gevent.sleep(.1) finally: source_chan.close() dest_socket.close() + gevent.sleep(0) def run(self): + gevent.sleep(.2) channel = self.transport.accept(20) + gevent.sleep(0) if not channel: return if not channel.get_id() == self.chanid: @@ -69,3 +74,4 @@ def run(self): except Exception as ex: logger.exception("Got exception creating tunnel - %s", ex,) logger.debug("Finished tunneling") + gevent.sleep(0) diff --git a/pssh/pssh_client.py b/pssh/pssh_client.py index 46a4737a..707cb68d 100644 --- a/pssh/pssh_client.py +++ b/pssh/pssh_client.py @@ -369,13 +369,19 @@ def run_command(self, *args, **kwargs): :rtype: Dictionary with host as key as per \ :mod:`pssh.pssh_client.ParallelSSHClient.get_output` - :raises: :mod:`pssh.exceptions.AuthenticationException` on authentication error - :raises: :mod:`pssh.exceptions.UnknownHostException` on DNS resolution error - :raises: :mod:`pssh.exceptions.ConnectionErrorException` on error connecting - :raises: :mod:`pssh.exceptions.SSHException` on other undefined SSH errors - :raises: :mod:`pssh.exceptions.HostArgumentException` on number of host \ - arguments not equal to number of hosts + :raises: :mod:`pssh.exceptions.AuthenticationException` on \ + authentication error + :raises: :mod:`pssh.exceptions.UnknownHostException` on DNS resolution \ + error + :raises: :mod:`pssh.exceptions.ConnectionErrorException` on error \ + connecting + :raises: :mod:`pssh.exceptions.SSHException` on other undefined SSH \ + errors + :raises: :mod:`pssh.exceptions.HostArgumentException` on number of \ + host arguments not equal to number of hosts :raises: `TypeError` on not enough host arguments for cmd string format + :raises: `KeyError` on no host argument key in arguments dict for cmd \ + string format **Example Usage** diff --git a/tests/test_pssh_client.py b/tests/test_pssh_client.py index d2b3dbb4..08a4547d 100644 --- a/tests/test_pssh_client.py +++ b/tests/test_pssh_client.py @@ -27,18 +27,17 @@ import warnings import shutil import sys -from multiprocessing import Process +from embedded_server.embedded_server import start_server, make_socket, \ + logger as server_logger, paramiko_logger, start_server_process +from embedded_server.fake_agent import FakeAgent +from paramiko import RSAKey import gevent from pssh import ParallelSSHClient, UnknownHostException, \ AuthenticationException, ConnectionErrorException, SSHException, \ logger as pssh_logger from pssh.exceptions import HostArgumentException from pssh.utils import load_private_key -from embedded_server.embedded_server import start_server, make_socket, \ - logger as server_logger, paramiko_logger, start_server_process -from embedded_server.fake_agent import FakeAgent -from paramiko import RSAKey PKEY_FILENAME = os.path.sep.join([os.path.dirname(__file__), 'test_client_private_key']) USER_KEY = RSAKey.from_private_key_file(PKEY_FILENAME) @@ -538,22 +537,24 @@ def test_ssh_proxy(self): proxy_server_port = self.make_random_port(proxy_host) proxy_server = start_server_process(proxy_host, listen_port=proxy_server_port) - client = ParallelSSHClient([self.host], port=self.listen_port, + client = ParallelSSHClient([self.host], port=39783, pkey=self.user_key, proxy_host=proxy_host, - proxy_port=proxy_server_port + proxy_port=proxy_server_port, ) gevent.sleep(2) - output = client.run_command(self.fake_cmd) - gevent.sleep(.2) - stdout = list(output[self.host]['stdout']) - expected_stdout = [self.fake_resp] - self.assertEqual(expected_stdout, stdout, - msg="Got unexpected stdout - %s, expected %s" % - (stdout, - expected_stdout,)) - del client - proxy_server.terminate() + try: + output = client.run_command(self.fake_cmd) + gevent.sleep(1) + stdout = list(output[self.host]['stdout']) + expected_stdout = [self.fake_resp] + self.assertEqual(expected_stdout, stdout, + msg="Got unexpected stdout - %s, expected %s" % + (stdout, + expected_stdout,)) + finally: + del client + proxy_server.terminate() def test_ssh_proxy_auth(self): """Test connecting to remote destination via SSH proxy @@ -589,18 +590,19 @@ def test_ssh_proxy_auth(self): proxy_password) self.assertTrue(client.host_clients[self.host].proxy_pkey) finally: + del client proxy_server.terminate() def test_ssh_proxy_auth_fail(self): """Test failures while connecting via proxy""" - listen_socket = make_socket(self.host) - listen_port = listen_socket.getsockname()[1] - self.server.kill() - server = start_server(listen_socket, - fail_auth=True) - proxy_server_socket = make_socket('127.0.0.2') - proxy_server_port = proxy_server_socket.getsockname()[1] - proxy_server = start_server(proxy_server_socket) + # listen_socket = make_socket(self.host) + proxy_host = '127.0.0.2' + listen_port = self.make_random_port() + server = start_server_process(self.host, listen_port=listen_port, + fail_auth=True) + proxy_server_port = self.make_random_port(host=proxy_host) + proxy_server = start_server_process(proxy_host, + listen_port=proxy_server_port) proxy_user = 'proxy_user' proxy_password = 'fake' gevent.sleep(2) @@ -611,12 +613,15 @@ def test_ssh_proxy_auth_fail(self): proxy_user=proxy_user, proxy_password='fake', proxy_pkey=self.user_key, + num_retries=1, ) gevent.sleep(2) - self.assertRaises(AuthenticationException, client.run_command, self.fake_cmd) - del client - server.kill() - proxy_server.kill() + try: + self.assertRaises(AuthenticationException, client.run_command, self.fake_cmd) + finally: + del client + server.terminate() + proxy_server.terminate() def test_bash_variable_substitution(self): """Test bash variables work correctly""" From f6c06a484f7aff22ca7c7da52181c40eabeeda40 Mon Sep 17 00:00:00 2001 From: Dan <22e889d8@opayq.com> Date: Fri, 6 Jan 2017 18:25:04 +0000 Subject: [PATCH 6/8] Removed gipc import --- embedded_server/embedded_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/embedded_server/embedded_server.py b/embedded_server/embedded_server.py index 2fe5da96..b346f835 100644 --- a/embedded_server/embedded_server.py +++ b/embedded_server/embedded_server.py @@ -38,7 +38,7 @@ *Warning* - Note that commands, with or without a shell, are actually run on the system running this server. Destructive commands will affect the system as permissions of user running the server allow. **Use at your own risk**. """ -from gipc import start_process +# from gipc import start_process from multiprocessing import Process import sys if 'threading' in sys.modules: From 51c711c5ee1cc59c969a4f6c8b4fa3483f263bde Mon Sep 17 00:00:00 2001 From: Dan <22e889d8@opayq.com> Date: Tue, 10 Jan 2017 17:43:27 +0000 Subject: [PATCH 7/8] Cleaned up embedded server, tunnel and fixed race condition. Updated tests for embedded server changes --- embedded_server/embedded_server.py | 257 +++++++++++++---------------- embedded_server/tunnel.py | 9 +- tests/test_pssh_client.py | 123 ++++++-------- 3 files changed, 175 insertions(+), 214 deletions(-) diff --git a/embedded_server/embedded_server.py b/embedded_server/embedded_server.py index b346f835..61318ad2 100644 --- a/embedded_server/embedded_server.py +++ b/embedded_server/embedded_server.py @@ -23,28 +23,44 @@ Implements: * Execution of commands via exec_command * Public key and password auth - * Direct TCP tunneling + * Direct TCP tunneling (port forwarding) * SSH agent forwarding * Stub SFTP server from Paramiko * Forced authentication failure + * Forced server timeout for connection timeout simulation -Does _not_ support interactive shells, our clients do not use them. +Does _not_ support interactive shells - it is intended for purely API driven use. -Server private key is hardcoded. Server listen code inspired by demo_server.py in -Paramiko repository. +An embedded private key is provided as `embedded_server.host_key` and may be overriden -Server runs asynchronously in its own greenlet. Call `start_server` with a new `multiprocessing.Process` to run it on a new process with its own event loop. +Server runs asynchronously in its own greenlet. *Warning* - Note that commands, with or without a shell, are actually run on the system running this server. Destructive commands will affect the system as permissions of user running the server allow. **Use at your own risk**. + +Example Usage +=============== + +from embedded_server import start_server, start_server_from_ip, make_socket + +Make server from existing socket +---------------------------------- + +socket = make_socket('127.0.0.1') +server = start_server(socket) + +Make server from IP and optionally port +----------------------------------------- + +server, listen_port = start_server_from_ip('127.0.0.1') +other_server, _ = start_server_from_ip('127.0.0.1', port=1234) """ -# from gipc import start_process -from multiprocessing import Process import sys if 'threading' in sys.modules: del sys.modules['threading'] from gevent import monkey monkey.patch_all() + import os import gevent from gevent import socket @@ -52,8 +68,8 @@ import sys import traceback import logging -import paramiko import time +import paramiko import gevent.subprocess import gevent.hub @@ -68,9 +84,10 @@ class Server(paramiko.ServerInterface): """Implements :mod:`paramiko.ServerInterface` to provide an - embedded SSH server implementation. + embedded SSH2 server implementation. - Start a `Server` with at least a host private key. + Start a `Server` with at least a :mod:`paramiko.Transport` object + and a host private key. Any SSH2 client with public key or password authentication is allowed, only. Interactive shell requests are not accepted. @@ -85,97 +102,17 @@ class Server(paramiko.ServerInterface): * Interactive shell requests """ - def __init__(self, host_key, fail_auth=False, - ssh_exception=False, - socket=None, - port=0, - host='127.0.0.1', - timeout=None): - if not socket: - self.socket = make_socket(host, port) - if not self.socket: - msg = "Could not establish listening connection on %s:%s" - logger.error(msg, host, port) - raise Exception(msg, host, port) - self.host = host - self.port = self.socket.getsockname()[1] + def __init__(self, transport, host_key, fail_auth=False, + ssh_exception=False): + paramiko.ServerInterface.__init__(self) + transport.load_server_moduli() + transport.add_server_key(host_key) + transport.set_subsystem_handler('sftp', paramiko.SFTPServer, StubSFTPServer) + self.transport = transport self.event = Event() self.fail_auth = fail_auth self.ssh_exception = ssh_exception self.host_key = host_key - self.transport = None - self.timeout = timeout - - def start_listening(self): - try: - gevent.sleep(0) - self.socket.listen(100) - logger.info('Listening for connection on %s:%s..', self.host, - self.port) - except Exception as e: - logger.error('*** Listen failed: %s' % (str(e),)) - traceback.print_exc() - raise - gevent.sleep() - conn, addr = self.socket.accept() - gevent.sleep(.2) - logger.info('Got connection..') - # import ipdb; ipdb.set_trace() - gevent.sleep(.2) - if self.timeout: - logger.debug("SSH server sleeping for %s then raising socket.timeout", - self.timeout) - gevent.Timeout(self.timeout).start() - self.transport = paramiko.Transport(conn) - self.transport.load_server_moduli() - self.transport.add_server_key(self.host_key) - self.transport.set_subsystem_handler('sftp', paramiko.SFTPServer, - StubSFTPServer) - gevent.sleep() - try: - self.transport.start_server(server=self) - except paramiko.SSHException as e: - logger.exception('SSH negotiation failed') - raise - gevent.sleep(0) - - def run(self): - while True: - try: - self.start_listening() - gevent.sleep(0) - except Exception: - logger.exception("Error occured starting server") - continue - gevent.sleep(0) - try: - self.accept_connections() - gevent.sleep(0) - except Exception as e: - logger.error('*** Caught exception: %s: %s' % (str(e.__class__), str(e),)) - traceback.print_exc() - try: - self.transport.close() - except Exception: - pass - raise - - def accept_connections(self): - while True: - gevent.sleep(0) - channel = self.transport.accept(20) - if not channel: - logger.error("Could not establish channel on %s:%s", - self.host, self.port) - gevent.sleep(0) - continue - while self.transport.is_active(): - logger.debug("Transport active, waiting..") - gevent.sleep(1) - while not channel.send_ready(): - gevent.sleep(.2) - channel.close() - gevent.sleep(0) def check_channel_request(self, kind, chanid): return paramiko.OPEN_SUCCEEDED @@ -205,28 +142,25 @@ def check_channel_pty_request(self, channel, term, width, height, pixelwidth, return True def check_channel_direct_tcpip_request(self, chanid, origin, destination): - # import ipdb; ipdb.set_trace() logger.debug("Proxy connection %s -> %s requested", origin, destination,) extra = {'username' : self.transport.get_username()} logger.debug("Starting proxy connection %s -> %s", origin, destination, extra=extra) - self.event.set() try: - gevent.sleep(.2) - tunnel = Process(target=Tunneler, args=(destination, self.transport, chanid,)) - tunnel.daemon = True + tunnel = Tunneler(destination, self.transport, chanid) tunnel.start() - gevent.sleep(.2) except Exception as ex: logger.error("Error creating proxy connection to %s - %s", destination, ex,) return paramiko.OPEN_FAILED_CONNECT_FAILED - gevent.sleep(2) + self.event.set() + gevent.sleep() + logger.debug("Proxy connection started") return paramiko.OPEN_SUCCEEDED def check_channel_forward_agent_request(self, channel): logger.debug("Forward agent key request for channel %s" % (channel,)) - gevent.sleep(0) + gevent.sleep() return True def check_channel_exec_request(self, channel, cmd, @@ -261,42 +195,87 @@ def make_socket(listen_ip, port=0): sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) sock.bind((listen_ip, port)) - except Exception as e: - logger.error('Failed to bind to address - %s' % (str(e),)) + except Exception as ex: + logger.error('Failed to bind to address - %s', ex) traceback.print_exc() return return sock -def start_server(listen_ip, fail_auth=False, ssh_exception=False, - timeout=None, - listen_port=0): - # gevent.reinit() - gevent.hub.reinit() - # h.destroy(destroy_loop=True) - # h = gevent.hub.Hub() - # h.NOT_ERROR = (Exception,) - # gevent.hub.set_hub(h) - server = Server(host_key, host=listen_ip, port=listen_port, - fail_auth=fail_auth, ssh_exception=ssh_exception, - timeout=timeout) +def listen(sock, fail_auth=False, ssh_exception=False, + timeout=None): + """Run server and given a cmd_to_run, send given + response to client connection. Returns (server, socket) tuple + where server is a joinable server thread and socket is listening + socket of server. + """ + # sock = make_socket(ip, port=port) + try: + sock.listen(100) + except Exception as e: + logger.error('*** Listen failed: %s' % (str(e),)) + traceback.print_exc() + return + host, port = sock.getsockname() + logger.info('Listening for connection on %s:%s..', host, port) + return handle_ssh_connection(sock, fail_auth=fail_auth, + timeout=timeout, ssh_exception=ssh_exception) + +def _handle_ssh_connection(transport, fail_auth=False, + ssh_exception=False): + server = Server(transport, host_key, + fail_auth=fail_auth, ssh_exception=ssh_exception) try: - server.run() - except KeyboardInterrupt: - sys.exit(0) - # listen_process = Process(target=server.run) - # listen_process.start() - # listen_process.join() - -def start_server_process(listen_ip, fail_auth=False, ssh_exception=False, - timeout=None, listen_port=0): - gevent.reinit() - server = Process(target=start_server, args=(listen_ip,), - kwargs={ - 'listen_port': listen_port, - 'fail_auth': fail_auth, - 'ssh_exception': ssh_exception, - 'timeout': timeout, - }) - server.start() - gevent.sleep(.2) - return server + transport.start_server(server=server) + except paramiko.SSHException as e: + logger.exception('SSH negotiation failed') + return + except Exception: + logger.exception("Error occured starting server") + return + # *Important* Allow other greenlets to execute before establishing connection + # which may be handled by said other greenlets + gevent.sleep(.5) + channel = transport.accept(20) + if not channel: + logger.error("Could not establish channel") + return + while transport.is_active(): + logger.debug("Transport active, waiting..") + gevent.sleep(1) + while not channel.send_ready(): + gevent.sleep(.2) + channel.close() + +def handle_ssh_connection(sock, + fail_auth=False, ssh_exception=False, + timeout=None): + conn, addr = sock.accept() + logger.info('Got connection..') + if timeout: + logger.debug("SSH server sleeping for %s then raising socket.timeout", + timeout) + gevent.Timeout(timeout).start().get() + try: + transport = paramiko.Transport(conn) + return _handle_ssh_connection(transport, fail_auth=fail_auth, + ssh_exception=ssh_exception) + except Exception as e: + logger.error('*** Caught exception: %s: %s' % (str(e.__class__), str(e),)) + traceback.print_exc() + try: + transport.close() + except Exception: + pass + +def start_server(sock, fail_auth=False, ssh_exception=False, + timeout=None): + return gevent.spawn(listen, sock, fail_auth=fail_auth, + timeout=timeout, ssh_exception=ssh_exception) + +def start_server_from_ip(ip, port=0, + fail_auth=False, ssh_exception=False, + timeout=None): + server_sock = make_socket(ip, port=port) + server = start_server(server_sock, fail_auth=fail_auth, + ssh_exception=ssh_exception, timeout=timeout) + return server, server_sock.getsockname()[1] diff --git a/embedded_server/tunnel.py b/embedded_server/tunnel.py index 5ec6efd2..4f769e14 100644 --- a/embedded_server/tunnel.py +++ b/embedded_server/tunnel.py @@ -24,12 +24,14 @@ from gevent import socket, select import logging -logger = logging.getLogger("fake_server") +logger = logging.getLogger("embedded_server.tunnel") -class Tunneler(gevent.Greenlet): +class Tunneler(gevent.Greenlet): + def __init__(self, address, transport, chanid): gevent.Greenlet.__init__(self) gevent.sleep(.2) + logger.info("Tunneller creating connection -> %s", address) self.socket = socket.create_connection(address) self.transport = transport self.chanid = chanid @@ -59,9 +61,8 @@ def tunnel(self, dest_socket, source_chan): gevent.sleep(0) def run(self): - gevent.sleep(.2) + logger.info("Tunnel waiting for connection") channel = self.transport.accept(20) - gevent.sleep(0) if not channel: return if not channel.get_id() == self.chanid: diff --git a/tests/test_pssh_client.py b/tests/test_pssh_client.py index 08a4547d..e2fd9595 100644 --- a/tests/test_pssh_client.py +++ b/tests/test_pssh_client.py @@ -28,16 +28,16 @@ import shutil import sys -from embedded_server.embedded_server import start_server, make_socket, \ - logger as server_logger, paramiko_logger, start_server_process -from embedded_server.fake_agent import FakeAgent -from paramiko import RSAKey import gevent from pssh import ParallelSSHClient, UnknownHostException, \ AuthenticationException, ConnectionErrorException, SSHException, \ logger as pssh_logger from pssh.exceptions import HostArgumentException from pssh.utils import load_private_key +from embedded_server.embedded_server import start_server, make_socket, \ + logger as server_logger, paramiko_logger, start_server_from_ip +from embedded_server.fake_agent import FakeAgent +from paramiko import RSAKey PKEY_FILENAME = os.path.sep.join([os.path.dirname(__file__), 'test_client_private_key']) USER_KEY = RSAKey.from_private_key_file(PKEY_FILENAME) @@ -54,9 +54,9 @@ def setUp(self): self.long_cmd = lambda lines: 'for (( i=0; i<%s; i+=1 )) do echo $i; sleep 1; done' % (lines,) self.user_key = USER_KEY self.host = '127.0.0.1' - self.listen_port = self.make_random_port() - self.server = start_server_process(self.host, - listen_port=self.listen_port) + self.server_sock = make_socket(self.host) + self.listen_port = self.server_sock.getsockname()[1] + self.server = start_server(self.server_sock) self.agent = FakeAgent() self.agent.add_key(USER_KEY) self.client = ParallelSSHClient([self.host], port=self.listen_port, @@ -72,7 +72,7 @@ def make_random_port(self, host=None): def tearDown(self): del self.client - self.server.terminate() + self.server.kill() def test_pssh_client_no_stdout_non_zero_exit_code_immediate_exit(self): output = self.client.run_command('exit 1') @@ -140,21 +140,20 @@ def test_pssh_client_run_long_command(self): expected_lines, len(stdout))) def test_pssh_client_auth_failure(self): - listen_port = self.make_random_port() - server = start_server_process(self.host, listen_port=listen_port, - fail_auth=True) + server, listen_port = start_server_from_ip(self.host, + fail_auth=True) client = ParallelSSHClient([self.host], port=listen_port, pkey=self.user_key, agent=self.agent) self.assertRaises(AuthenticationException, client.run_command, self.fake_cmd) del client - server.terminate() + server.kill() def test_pssh_client_hosts_list_part_failure(self): """Test getting output for remainder of host list in the case where one host in the host list has a failure""" - server2 = start_server_process('127.0.0.2', listen_port=self.listen_port, - fail_auth=True) + server2, _ = start_server_from_ip('127.0.0.2', port=self.listen_port, + fail_auth=True) hosts = [self.host, '127.0.0.2'] client = ParallelSSHClient(hosts, port=self.listen_port, @@ -179,13 +178,11 @@ def test_pssh_client_hosts_list_part_failure(self): raise Exception("Expected AuthenticationException, got %s instead" % ( output[hosts[1]]['exception'],)) del client - server2.terminate() + server2.kill() def test_pssh_client_ssh_exception(self): - listen_port = self.make_random_port() - server = start_server_process(self.host, - listen_port=listen_port, - ssh_exception=True) + server, listen_port = start_server_from_ip(self.host, + ssh_exception=True) client = ParallelSSHClient([self.host], user='fakey', password='fakey', port=listen_port, @@ -194,15 +191,13 @@ def test_pssh_client_ssh_exception(self): ) self.assertRaises(SSHException, client.run_command, self.fake_cmd) del client - server.terminate() + server.kill() def test_pssh_client_timeout(self): - listen_port = self.make_random_port() server_timeout=0.2 client_timeout=server_timeout-0.1 - server = start_server_process(self.host, - listen_port=listen_port, - timeout=server_timeout) + server, listen_port = start_server_from_ip(self.host, + timeout=server_timeout) client = ParallelSSHClient([self.host], port=listen_port, pkey=self.user_key, timeout=client_timeout, @@ -212,12 +207,12 @@ def test_pssh_client_timeout(self): try: gevent.sleep(server_timeout+0.2) client.join(output) - if not server.exitcode == 1: + if not server.exception: raise Exception( "Expected gevent.Timeout from socket timeout, got none") finally: del client - server.terminate() + server.kill() def test_pssh_client_run_command_password(self): """Test password authentication. Embedded server accepts any password @@ -473,7 +468,7 @@ def test_pssh_hosts_more_than_pool_size(self): get logs for all hosts""" # Make a second server on the same port as the first one host2 = '127.0.0.2' - server2 = start_server_process(host2, listen_port=self.listen_port) + server2, _ = start_server_from_ip(host2, port=self.listen_port) hosts = [self.host, host2] client = ParallelSSHClient(hosts, port=self.listen_port, @@ -490,13 +485,13 @@ def test_pssh_hosts_more_than_pool_size(self): msg="Did not get expected output from all hosts. \ Got %s - expected %s" % (stdout, expected_stdout,)) del client - server2.terminate() + server2.kill() def test_pssh_hosts_iterator_hosts_modification(self): """Test using iterator as host list and modifying host list in place""" host2, host3 = '127.0.0.2', '127.0.0.3' - server2 = start_server_process(host2, listen_port=self.listen_port) - server3 = start_server_process(host3, listen_port=self.listen_port) + server2, _ = start_server_from_ip(host2, port=self.listen_port) + server3, _ = start_server_from_ip(host3, port=self.listen_port) hosts = [self.host, '127.0.0.2'] client = ParallelSSHClient(iter(hosts), port=self.listen_port, @@ -525,27 +520,28 @@ def test_pssh_hosts_iterator_hosts_modification(self): self.assertTrue(hosts[1] in output, msg="Did not get output for new host %s" % (hosts[1],)) del client - server2.terminate() - server3.terminate() + server2.kill() + server3.kill() def test_ssh_proxy(self): """Test connecting to remote destination via SSH proxy client -> proxy -> destination Proxy SSH server accepts no commands and sends no responses, only proxies to destination. Destination accepts a command as usual.""" + del self.client + self.client = None + self.server.kill() + server, _ = start_server_from_ip(self.host, port=self.listen_port) proxy_host = '127.0.0.2' proxy_server_port = self.make_random_port(proxy_host) - proxy_server = start_server_process(proxy_host, - listen_port=proxy_server_port) - client = ParallelSSHClient([self.host], port=39783, + proxy_server, proxy_server_port = start_server_from_ip(proxy_host) + client = ParallelSSHClient([self.host], port=self.listen_port, pkey=self.user_key, proxy_host=proxy_host, proxy_port=proxy_server_port, ) - gevent.sleep(2) try: output = client.run_command(self.fake_cmd) - gevent.sleep(1) stdout = list(output[self.host]['stdout']) expected_stdout = [self.fake_resp] self.assertEqual(expected_stdout, stdout, @@ -554,7 +550,8 @@ def test_ssh_proxy(self): expected_stdout,)) finally: del client - proxy_server.terminate() + server.kill() + proxy_server.kill() def test_ssh_proxy_auth(self): """Test connecting to remote destination via SSH proxy @@ -562,11 +559,9 @@ def test_ssh_proxy_auth(self): Proxy SSH server accepts no commands and sends no responses, only proxies to destination. Destination accepts a command as usual.""" host2 = '127.0.0.2' - proxy_server_port = self.make_random_port(host=host2) - proxy_server = start_server_process(host2, listen_port=proxy_server_port) + proxy_server, proxy_server_port = start_server_from_ip(host2) proxy_user = 'proxy_user' proxy_password = 'fake' - gevent.sleep(2) client = ParallelSSHClient([self.host], port=self.listen_port, pkey=self.user_key, proxy_host=host2, @@ -576,7 +571,6 @@ def test_ssh_proxy_auth(self): proxy_pkey=self.user_key, num_retries=1, ) - gevent.sleep(2) expected_stdout = [self.fake_resp] try: output = client.run_command(self.fake_cmd) @@ -591,21 +585,15 @@ def test_ssh_proxy_auth(self): self.assertTrue(client.host_clients[self.host].proxy_pkey) finally: del client - proxy_server.terminate() + proxy_server.kill() def test_ssh_proxy_auth_fail(self): """Test failures while connecting via proxy""" - # listen_socket = make_socket(self.host) proxy_host = '127.0.0.2' - listen_port = self.make_random_port() - server = start_server_process(self.host, listen_port=listen_port, - fail_auth=True) - proxy_server_port = self.make_random_port(host=proxy_host) - proxy_server = start_server_process(proxy_host, - listen_port=proxy_server_port) + server, listen_port = start_server_from_ip(self.host, fail_auth=True) + proxy_server, proxy_server_port = start_server_from_ip(proxy_host) proxy_user = 'proxy_user' proxy_password = 'fake' - gevent.sleep(2) client = ParallelSSHClient([self.host], port=listen_port, pkey=self.user_key, proxy_host='127.0.0.2', @@ -615,13 +603,12 @@ def test_ssh_proxy_auth_fail(self): proxy_pkey=self.user_key, num_retries=1, ) - gevent.sleep(2) try: self.assertRaises(AuthenticationException, client.run_command, self.fake_cmd) finally: del client - server.terminate() - proxy_server.terminate() + server.kill() + proxy_server.kill() def test_bash_variable_substitution(self): """Test bash variables work correctly""" @@ -676,9 +663,7 @@ def test_connection_error_exception(self): def test_authentication_exception(self): """Test that we get authentication exception in output with correct arguments""" - port = self.make_random_port() - server = start_server_process(self.host, fail_auth=True, - listen_port=port) + server, port = start_server_from_ip(self.host, fail_auth=True) hosts = [self.host] client = ParallelSSHClient(hosts, port=port, pkey=self.user_key, @@ -700,14 +685,12 @@ def test_authentication_exception(self): ex.args[2], port,)) else: raise Exception("Expected AuthenticationException") - server.terminate() + server.kill() def test_ssh_exception(self): """Test that we get ssh exception in output with correct arguments""" host = '127.0.0.10' - port = self.make_random_port(host=host) - server = start_server_process(host, ssh_exception=True, - listen_port=port) + server, port = start_server_from_ip(host, ssh_exception=True) hosts = [host] client = ParallelSSHClient(hosts, port=port, user='fakey', password='fakey', @@ -729,7 +712,7 @@ def test_ssh_exception(self): ex.args[2], port,)) else: raise Exception("Expected SSHException") - server.terminate() + server.kill() def test_multiple_single_quotes_in_cmd(self): """Test that we can run a command with multiple single quotes""" @@ -785,13 +768,11 @@ def test_host_config(self): user = 'overriden_user' password = 'overriden_pass' for host in hosts: - port = self.make_random_port(host=host) + server, port = start_server_from_ip(host, fail_auth=hosts.index(host)) host_config[host] = {} host_config[host]['port'] = port host_config[host]['user'] = user host_config[host]['password'] = password - server = start_server_process(host, fail_auth=hosts.index(host), - listen_port=port) servers.append(server) pkey_data = load_private_key(PKEY_FILENAME) host_config[hosts[0]]['private_key'] = pkey_data @@ -816,7 +797,7 @@ def test_host_config(self): self.assertTrue(client.host_clients[hosts[0]].pkey == pkey_data, msg="Host config pkey override failed") for server in servers: - server.terminate() + server.kill() def test_pssh_client_override_allow_agent_authentication(self): """Test running command with allow_agent set to False""" @@ -846,8 +827,8 @@ def test_get_exit_codes_bad_output(self): def test_per_host_tuple_args(self): host2, host3 = '127.0.0.2', '127.0.0.3' - server2 = start_server_process(host2, listen_port=self.listen_port) - server3 = start_server_process(host3, listen_port=self.listen_port) + server2, _ = start_server_from_ip(host2, port=self.listen_port) + server3, _ = start_server_from_ip(host3, port=self.listen_port) hosts = [self.host, host2, host3] host_args = ('arg1', 'arg2', 'arg3') cmd = 'echo %s' @@ -874,12 +855,12 @@ def test_per_host_tuple_args(self): host_args = (('arg1', ),) self.assertRaises(TypeError, client.run_command, cmd, host_args=host_args) for server in [server2, server3]: - server.terminate() + server.kill() def test_per_host_dict_args(self): host2, host3 = '127.0.0.2', '127.0.0.3' - server2 = start_server_process(host2, listen_port=self.listen_port) - server3 = start_server_process(host3, listen_port=self.listen_port) + server2, _ = start_server_from_ip(host2, port=self.listen_port) + server3, _ = start_server_from_ip(host3, port=self.listen_port) hosts = [self.host, host2, host3] hosts_gen = (h for h in hosts) host_args = [dict(zip(('host_arg1', 'host_arg2',), From 5c55e41a7bfff5e348867804774acf73b9f6bad0 Mon Sep 17 00:00:00 2001 From: Dan <22e889d8@opayq.com> Date: Tue, 10 Jan 2017 17:44:25 +0000 Subject: [PATCH 8/8] Updated tests to use pssh.agent for agent tests. Removed fake agent from embedded server --- embedded_server/fake_agent.py | 43 ----------------------------------- pssh/agent.py | 1 + tests/test_pssh_client.py | 6 ++--- tests/test_ssh_client.py | 7 ++++-- 4 files changed, 9 insertions(+), 48 deletions(-) delete mode 100644 embedded_server/fake_agent.py diff --git a/embedded_server/fake_agent.py b/embedded_server/fake_agent.py deleted file mode 100644 index ab22fa31..00000000 --- a/embedded_server/fake_agent.py +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright (C) 2015 Panos Kittenis - -# This library is free software; you can redistribute it and/or -# modify it under the terms of the GNU Lesser General Public -# License as published by the Free Software Foundation, version 2.1. - -# This library is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -# Lesser General Public License for more details. - -# You should have received a copy of the GNU Lesser General Public -# License along with this library; if not, write to the Free Software -# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA - -""" -Fake SSH Agent for testing ssh agent forwarding and agent based key -authentication -""" - -import paramiko.agent - -class FakeAgent(paramiko.agent.AgentSSH): - - def __init__(self): - self._conn = None - self.keys = [] - - def add_key(self, key): - """Add key to agent. - :param key: Key to add - :type key: :mod:`paramiko.pkey.PKey` - """ - self.keys.append(key) - - def _connect(self, conn): - pass - - def _close(self): - self._keys = [] - - def get_keys(self): - return tuple(self.keys) diff --git a/pssh/agent.py b/pssh/agent.py index 44c590da..869b12f9 100644 --- a/pssh/agent.py +++ b/pssh/agent.py @@ -15,6 +15,7 @@ # License along with this library; if not, write to the Free Software # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA +import paramiko.agent class SSHAgent(paramiko.agent.AgentSSH): """:mod:`paramiko.agent.AgentSSH` compatible class for programmatically diff --git a/tests/test_pssh_client.py b/tests/test_pssh_client.py index e2fd9595..ae891b49 100644 --- a/tests/test_pssh_client.py +++ b/tests/test_pssh_client.py @@ -36,7 +36,7 @@ from pssh.utils import load_private_key from embedded_server.embedded_server import start_server, make_socket, \ logger as server_logger, paramiko_logger, start_server_from_ip -from embedded_server.fake_agent import FakeAgent +from pssh.agent import SSHAgent from paramiko import RSAKey PKEY_FILENAME = os.path.sep.join([os.path.dirname(__file__), 'test_client_private_key']) @@ -57,7 +57,7 @@ def setUp(self): self.server_sock = make_socket(self.host) self.listen_port = self.server_sock.getsockname()[1] self.server = start_server(self.server_sock) - self.agent = FakeAgent() + self.agent = SSHAgent() self.agent.add_key(USER_KEY) self.client = ParallelSSHClient([self.host], port=self.listen_port, pkey=self.user_key, @@ -73,6 +73,7 @@ def make_random_port(self, host=None): def tearDown(self): del self.client self.server.kill() + del self.agent def test_pssh_client_no_stdout_non_zero_exit_code_immediate_exit(self): output = self.client.run_command('exit 1') @@ -533,7 +534,6 @@ def test_ssh_proxy(self): self.server.kill() server, _ = start_server_from_ip(self.host, port=self.listen_port) proxy_host = '127.0.0.2' - proxy_server_port = self.make_random_port(proxy_host) proxy_server, proxy_server_port = start_server_from_ip(proxy_host) client = ParallelSSHClient([self.host], port=self.listen_port, pkey=self.user_key, diff --git a/tests/test_ssh_client.py b/tests/test_ssh_client.py index ef966ef4..34dddeeb 100644 --- a/tests/test_ssh_client.py +++ b/tests/test_ssh_client.py @@ -30,7 +30,7 @@ logger, ConnectionErrorException, UnknownHostException, SSHException, utils from embedded_server.embedded_server import start_server, make_socket, logger as server_logger, \ paramiko_logger -from embedded_server.fake_agent import FakeAgent +from pssh.agent import SSHAgent import paramiko import os from test_pssh_client import USER_KEY @@ -250,7 +250,7 @@ def test_ssh_agent_authentication(self): instead override the client's agent with our own fake SSH agent, add our to key to agent and try to login to server. Key should be automatically picked up from the overriden agent""" - agent = FakeAgent() + agent = SSHAgent() agent.add_key(USER_KEY) client = SSHClient(self.host, port=self.listen_port, agent=agent) @@ -261,6 +261,9 @@ def test_ssh_agent_authentication(self): self.assertEqual(expected, output, msg = "Got unexpected command output - %s" % (output,)) del client + agent._connect(None) + agent._close() + del agent def test_ssh_client_conn_failure(self): """Test connection error failure case - ConnectionErrorException"""