diff --git a/pssh/pssh_client.py b/pssh/pssh_client.py index 84718a4c..bde85473 100644 --- a/pssh/pssh_client.py +++ b/pssh/pssh_client.py @@ -50,7 +50,7 @@ def __init__(self, hosts, user=None, password=None, port=None, pkey=None, forward_ssh_agent=True, num_retries=DEFAULT_RETRIES, timeout=120, pool_size=10, proxy_host=None, proxy_port=22, - agent=None, host_config=None, channel_timeout=None): + agent=None, allow_agent=True, host_config=None, channel_timeout=None): """ :param hosts: Hosts to connect to :type hosts: list(str) @@ -96,7 +96,10 @@ def __init__(self, hosts, :param channel_timeout: (Optional) Time in seconds before an SSH operation \ times out. :type channel_timeout: int - + :type channel_timeout: int + :param allow_agent: (Optional) set to False to disable connecting to \ + the SSH agent + :type allow_agent: bool **Example Usage** >>> from pssh.pssh_client import ParallelSSHClient @@ -255,6 +258,7 @@ def __init__(self, hosts, # To hold host clients self.host_clients = {} self.agent = agent + self.allow_agent = allow_agent self.host_config = host_config if host_config else {} self.channel_timeout = channel_timeout @@ -474,7 +478,7 @@ def _exec_command(self, host, *args, **kwargs): timeout=self.timeout, proxy_host=self.proxy_host, proxy_port=self.proxy_port, - agent=self.agent, + allow_agent=self.allow_agent, agent=self.agent, channel_timeout=self.channel_timeout) return self.host_clients[host].exec_command(*args, **kwargs) diff --git a/pssh/ssh_client.py b/pssh/ssh_client.py index 8bc614b9..439141c8 100644 --- a/pssh/ssh_client.py +++ b/pssh/ssh_client.py @@ -41,8 +41,9 @@ class SSHClient(object): def __init__(self, host, user=None, password=None, port=None, pkey=None, forward_ssh_agent=True, - num_retries=DEFAULT_RETRIES, agent=None, timeout=10, - proxy_host=None, proxy_port=22, channel_timeout=None): + num_retries=DEFAULT_RETRIES, agent=None, + allow_agent=True, timeout=10, proxy_host=None, + proxy_port=22, channel_timeout=None): """Connect to host honouring any user set configuration in ~/.ssh/config \ or /etc/ssh/ssh_config @@ -74,6 +75,10 @@ def __init__(self, host, connecting to local SSH agent to lookup keys with our own SSH agent \ object. :type agent: :mod:`paramiko.agent.Agent` + :param forward_ssh_agent: (Optional) Turn on SSH agent forwarding - \ + equivalent to `ssh -A` from the `ssh` command line utility. \ + Defaults to True if not set. + :type forward_ssh_agent: bool :param proxy_host: (Optional) SSH host to tunnel connection through \ so that SSH clients connects to self.host via client -> proxy_host -> host :type proxy_host: str @@ -83,6 +88,9 @@ def __init__(self, host, :param channel_timeout: (Optional) Time in seconds before an SSH operation \ times out. :type channel_timeout: int + :param allow_agent: (Optional) set to False to disable connecting to \ + the SSH agent + :type allow_agent: bool """ ssh_config = paramiko.SSHConfig() _ssh_config_file = os.path.sep.join([os.path.expanduser('~'), @@ -107,6 +115,7 @@ def __init__(self, host, self.pkey = pkey self.port = port if port else 22 self.host = resolved_address + self.allow_agent = allow_agent if agent: self.client._agent = agent self.num_retries = num_retries @@ -158,7 +167,8 @@ def _connect(self, client, host, port, sock=None, retries=1): client.connect(host, username=self.user, password=self.password, port=port, pkey=self.pkey, - sock=sock, timeout=self.timeout) + sock=sock, timeout=self.timeout, + allow_agent=self.allow_agent) except sock_gaierror, ex: logger.error("Could not resolve host '%s' - retry %s/%s", self.host, retries, self.num_retries) diff --git a/tests/test_pssh_client.py b/tests/test_pssh_client.py index a9fb26f1..5a1e4613 100644 --- a/tests/test_pssh_client.py +++ b/tests/test_pssh_client.py @@ -690,3 +690,29 @@ def test_host_config(self): msg="Host config pkey override failed") for (server, _) in servers: server.kill() + + def test_pssh_client_override_allow_agent_authentication(self): + """Test running command with allow_agent set to False""" + client = ParallelSSHClient([self.host], port=self.listen_port, + pkey=self.user_key, allow_agent=False) + output = client.run_command(self.fake_cmd) + expected_exit_code = 0 + expected_stdout = [self.fake_resp] + expected_stderr = [] + + stdout = list(output[self.host]['stdout']) + stderr = list(output[self.host]['stderr']) + exit_code = output[self.host]['exit_code'] + self.assertEqual(expected_exit_code, exit_code, + msg="Got unexpected exit code - %s, expected %s" % + (exit_code, + expected_exit_code,)) + self.assertEqual(expected_stdout, stdout, + msg="Got unexpected stdout - %s, expected %s" % + (stdout, + expected_stdout,)) + self.assertEqual(expected_stderr, stderr, + msg="Got unexpected stderr - %s, expected %s" % + (stderr, + expected_stderr,)) + del client