From 059cb62478efd17742cb9f9acd975a54ef8b8a64 Mon Sep 17 00:00:00 2001 From: Dan <22e889d8@opayq.com> Date: Fri, 2 Sep 2016 17:25:34 +0100 Subject: [PATCH] Added proxy authentication parameters - resolves #18 --- pssh/pssh_client.py | 13 ++++++-- pssh/ssh_client.py | 29 +++++++++-------- tests/test_pssh_client.py | 65 +++++++++++++++++++++++++++++++++++++-- 3 files changed, 89 insertions(+), 18 deletions(-) diff --git a/pssh/pssh_client.py b/pssh/pssh_client.py index ed94c4de..2405f180 100644 --- a/pssh/pssh_client.py +++ b/pssh/pssh_client.py @@ -49,7 +49,8 @@ class ParallelSSHClient(object): def __init__(self, hosts, user=None, password=None, port=None, pkey=None, forward_ssh_agent=True, num_retries=DEFAULT_RETRIES, timeout=120, - pool_size=10, proxy_host=None, proxy_port=22, + pool_size=10, proxy_host=None, proxy_port=22, proxy_user=None, + proxy_password=None, proxy_pkey=None, agent=None, allow_agent=True, host_config=None, channel_timeout=None): """ :param hosts: Hosts to connect to @@ -298,7 +299,9 @@ def __init__(self, hosts, self.pkey = pkey self.num_retries = num_retries self.timeout = timeout - self.proxy_host, self.proxy_port = proxy_host, proxy_port + self.proxy_host, self.proxy_port, self.proxy_user, self.proxy_password, \ + self.proxy_pkey = proxy_host, proxy_port, proxy_user, \ + proxy_password, proxy_pkey # To hold host clients self.host_clients = {} self.agent = agent @@ -555,7 +558,11 @@ def _exec_command(self, host, *args, **kwargs): timeout=self.timeout, proxy_host=self.proxy_host, proxy_port=self.proxy_port, - allow_agent=self.allow_agent, agent=self.agent, + proxy_user=self.proxy_user, + proxy_password=self.proxy_password, + proxy_pkey=self.proxy_pkey, + allow_agent=self.allow_agent, + agent=self.agent, channel_timeout=self.channel_timeout) return self.host_clients[host].exec_command(*args, **kwargs) diff --git a/pssh/ssh_client.py b/pssh/ssh_client.py index 28f307d3..26759f1a 100644 --- a/pssh/ssh_client.py +++ b/pssh/ssh_client.py @@ -44,7 +44,8 @@ def __init__(self, host, pkey=None, forward_ssh_agent=True, num_retries=DEFAULT_RETRIES, agent=None, allow_agent=True, timeout=10, proxy_host=None, - proxy_port=22, channel_timeout=None, + proxy_port=22, proxy_user=None, proxy_password=None, + proxy_pkey=None, channel_timeout=None, _openssh_config_file=None): """Connect to host honouring any user set configuration in ~/.ssh/config \ or /etc/ssh/ssh_config @@ -115,7 +116,9 @@ def __init__(self, host, self.num_retries = num_retries self.timeout = timeout self.channel_timeout = channel_timeout - self.proxy_host, self.proxy_port = proxy_host, proxy_port + self.proxy_host, self.proxy_port, self.proxy_user, self.proxy_password, \ + self.proxy_pkey = proxy_host, proxy_port, proxy_user, \ + proxy_password, proxy_pkey self.proxy_client = None if self.proxy_host and self.proxy_port: logger.debug("Proxy configured for destination host %s - Proxy host: %s:%s", @@ -134,13 +137,14 @@ def _connect_tunnel(self): """ self.proxy_client = paramiko.SSHClient() self.proxy_client.set_missing_host_key_policy(paramiko.MissingHostKeyPolicy()) - self._connect(self.proxy_client, self.proxy_host, self.proxy_port) + self._connect(self.proxy_client, self.proxy_host, self.proxy_port, + user=self.proxy_user, password=self.proxy_password, + pkey=self.proxy_pkey) logger.info("Connecting via SSH proxy %s:%s -> %s:%s", self.proxy_host, self.proxy_port, self.host, self.port,) try: - proxy_channel = self.proxy_client.get_transport().\ - open_channel('direct-tcpip', (self.host, self.port,), - ('127.0.0.1', 0)) + proxy_channel = self.proxy_client.get_transport().open_channel( + 'direct-tcpip', (self.host, self.port,), ('127.0.0.1', 0)) sleep(0) return self._connect(self.client, self.host, self.port, sock=proxy_channel) except ChannelException, ex: @@ -149,7 +153,8 @@ def _connect_tunnel(self): self.host, self.port, str(error_type)) - def _connect(self, client, host, port, sock=None, retries=1): + def _connect(self, client, host, port, sock=None, retries=1, + user=None, password=None, pkey=None): """Connect to host :raises: :mod:`pssh.exceptions.AuthenticationException` on authentication error @@ -158,20 +163,20 @@ def _connect(self, client, host, port, sock=None, retries=1): :raises: :mod:`pssh.exceptions.SSHException` on other undefined SSH errors """ try: - client.connect(host, username=self.user, - password=self.password, port=port, - pkey=self.pkey, + client.connect(host, username=user if user else self.user, + password=password if password else self.password, + port=port, pkey=pkey if pkey else self.pkey, sock=sock, timeout=self.timeout, allow_agent=self.allow_agent) except sock_gaierror, ex: logger.error("Could not resolve host '%s' - retry %s/%s", - self.host, retries, self.num_retries) + host, retries, self.num_retries) while retries < self.num_retries: sleep(5) return self._connect(client, host, port, sock=sock, retries=retries+1) raise UnknownHostException("Unknown host %s - %s - retry %s/%s", - self.host, str(ex.args[1]), retries, + host, str(ex.args[1]), retries, self.num_retries) except sock_error, ex: logger.error("Error connecting to host '%s:%s' - retry %s/%s", diff --git a/tests/test_pssh_client.py b/tests/test_pssh_client.py index ff479af5..ae1b9525 100644 --- a/tests/test_pssh_client.py +++ b/tests/test_pssh_client.py @@ -184,7 +184,7 @@ def test_pssh_client_auth_failure(self): except AuthenticationException: pass del client - server.join() + server.kill() def test_pssh_client_hosts_list_part_failure(self): """Test getting output for remainder of host list in the case where one @@ -230,7 +230,7 @@ def test_pssh_client_ssh_exception(self): ) self.assertRaises(SSHException, client.run_command, self.fake_cmd) del client - server.join() + server.kill() def test_pssh_client_timeout(self): listen_socket = make_socket(self.host) @@ -258,7 +258,7 @@ def test_pssh_client_timeout(self): # msg="Channel timeout %s does not match requested timeout %s" %( # chan_timeout, client_timeout,)) del client - server.join() + server.kill() def test_pssh_client_exec_command_password(self): """Test password authentication. Embedded server accepts any password @@ -500,6 +500,65 @@ def test_ssh_proxy(self): self.server.kill() proxy_server.kill() + def test_ssh_proxy_auth(self): + """Test connecting to remote destination via SSH proxy + client -> proxy -> destination + Proxy SSH server accepts no commands and sends no responses, only + proxies to destination. Destination accepts a command as usual.""" + proxy_server_socket = make_socket('127.0.0.2') + proxy_server_port = proxy_server_socket.getsockname()[1] + proxy_server = start_server(proxy_server_socket) + proxy_user = 'proxy_user' + proxy_password = 'fake' + gevent.sleep(2) + client = ParallelSSHClient([self.host], port=self.listen_port, + pkey=self.user_key, + proxy_host='127.0.0.2', + proxy_port=proxy_server_port, + proxy_user=proxy_user, + proxy_password='fake', + proxy_pkey=self.user_key, + ) + gevent.sleep(2) + output = client.run_command(self.fake_cmd) + stdout = list(output[self.host]['stdout']) + expected_stdout = [self.fake_resp] + self.assertEqual(expected_stdout, stdout, + msg="Got unexpected stdout - %s, expected %s" % ( + stdout, expected_stdout,)) + self.assertEqual(client.host_clients[self.host].proxy_user, proxy_user) + self.assertEqual(client.host_clients[self.host].proxy_password, proxy_password) + self.assertTrue(client.host_clients[self.host].proxy_pkey) + self.server.kill() + proxy_server.kill() + + def test_ssh_proxy_auth_fail(self): + """Test failures while connecting via proxy""" + listen_socket = make_socket(self.host) + listen_port = listen_socket.getsockname()[1] + self.server.kill() + server = start_server(listen_socket, + fail_auth=True) + proxy_server_socket = make_socket('127.0.0.2') + proxy_server_port = proxy_server_socket.getsockname()[1] + proxy_server = start_server(proxy_server_socket) + proxy_user = 'proxy_user' + proxy_password = 'fake' + gevent.sleep(2) + client = ParallelSSHClient([self.host], port=listen_port, + pkey=self.user_key, + proxy_host='127.0.0.2', + proxy_port=proxy_server_port, + proxy_user=proxy_user, + proxy_password='fake', + proxy_pkey=self.user_key, + ) + gevent.sleep(2) + self.assertRaises(AuthenticationException, client.run_command, self.fake_cmd) + del client + server.kill() + proxy_server.kill() + def test_bash_variable_substitution(self): """Test bash variables work correctly""" client = ParallelSSHClient([self.host], port=self.listen_port,