Permalink
Browse files

Enable host_key checking at the strategy level

Implements a new method in the ssh connection plugin (fetch_and_store_key)
which is used to prefetch the key using ssh-keyscan.
  • Loading branch information...
jimi-c committed Dec 15, 2015
1 parent 15135f3 commit e5c2c03dea0998872a6b16a18d6c187685a5fc7a
@@ -32,6 +32,7 @@
from ansible.playbook.conditional import Conditional
from ansible.playbook.task import Task
from ansible.template import Templar
+from ansible.utils.connection import get_smart_connection_type
from ansible.utils.encrypt import key_for_hostname
from ansible.utils.listify import listify_lookup_plugin_terms
from ansible.utils.unicode import to_unicode
@@ -564,21 +565,7 @@ def _get_connection(self, variables, templar):
conn_type = self._play_context.connection
if conn_type == 'smart':
- conn_type = 'ssh'
- if sys.platform.startswith('darwin') and self._play_context.password:
- # due to a current bug in sshpass on OSX, which can trigger
- # a kernel panic even for non-privileged users, we revert to
- # paramiko on that OS when a SSH password is specified
- conn_type = "paramiko"
- else:
- # see if SSH can support ControlPersist if not use paramiko
- try:
- cmd = subprocess.Popen(['ssh','-o','ControlPersist'], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
- (out, err) = cmd.communicate()
- if "Bad configuration option" in err or "Usage:" in err:
- conn_type = "paramiko"
- except OSError:
- conn_type = "paramiko"
+ conn_type = get_smart_connection_type(self._play_context)
connection = self._shared_loader_obj.connection_loader.get(conn_type, self._play_context, self._new_stdin)
if not connection:
@@ -57,6 +57,7 @@ def serialize(self):
name=self.name,
vars=self.vars.copy(),
address=self.address,
+ has_hostkey=self.has_hostkey,
uuid=self._uuid,
gathered_facts=self._gathered_facts,
groups=groups,
@@ -65,10 +66,11 @@ def serialize(self):
def deserialize(self, data):
self.__init__()
- self.name = data.get('name')
- self.vars = data.get('vars', dict())
- self.address = data.get('address', '')
- self._uuid = data.get('uuid', uuid.uuid4())
+ self.name = data.get('name')
+ self.vars = data.get('vars', dict())
+ self.address = data.get('address', '')
+ self.has_hostkey = data.get('has_hostkey', False)
+ self._uuid = data.get('uuid', uuid.uuid4())
groups = data.get('groups', [])
for group_data in groups:
@@ -89,6 +91,7 @@ def __init__(self, name=None, port=None):
self._gathered_facts = False
self._uuid = uuid.uuid4()
+ self.has_hostkey = False
def __repr__(self):
return self.get_name()
@@ -23,11 +23,11 @@
import fcntl
import gettext
import os
-from abc import ABCMeta, abstractmethod, abstractproperty
+from abc import ABCMeta, abstractmethod, abstractproperty
from functools import wraps
-from ansible.compat.six import with_metaclass
+from ansible.compat.six import with_metaclass
from ansible import constants as C
from ansible.errors import AnsibleError
from ansible.plugins import shell_loader
@@ -233,3 +233,4 @@ def connection_unlock(self):
f = self._play_context.connection_lockfd
fcntl.lockf(f, fcntl.LOCK_UN)
display.vvvv('CONNECTION: pid %d released lock on %d' % (os.getpid(), f))
+
@@ -19,7 +19,12 @@
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type
+from ansible.compat.six import text_type
+
+import base64
import fcntl
+import hmac
+import operator
import os
import pipes
import pty
@@ -28,9 +33,13 @@
import subprocess
import time
+from hashlib import md5, sha1, sha256
+
from ansible import constants as C
from ansible.errors import AnsibleError, AnsibleConnectionFailure, AnsibleFileNotFound
from ansible.plugins.connection import ConnectionBase
+from ansible.utils.boolean import boolean
+from ansible.utils.connection import get_smart_connection_type
from ansible.utils.path import unfrackpath, makedirs_safe
from ansible.utils.unicode import to_bytes, to_unicode
@@ -41,7 +50,128 @@
display = Display()
SSHPASS_AVAILABLE = None
+HASHED_KEY_MAGIC = "|1|"
+
+def split_args(argstring):
+ """
+ Takes a string like '-o Foo=1 -o Bar="foo bar"' and returns a
+ list ['-o', 'Foo=1', '-o', 'Bar=foo bar'] that can be added to
+ the argument list. The list will not contain any empty elements.
+ """
+ return [to_unicode(x.strip()) for x in shlex.split(to_bytes(argstring)) if x.strip()]
+
+def get_ssh_opts(play_context):
+ # FIXME: caching may help here
+ opts_dict = dict()
+ try:
+ cmd = ['ssh', '-G', play_context.remote_addr]
+ res = subprocess.check_output(cmd)
+ for line in res.split('\n'):
+ if ' ' in line:
+ (key, val) = line.split(' ', 1)
+ else:
+ key = line
+ val = ''
+ opts_dict[key.lower()] = val
+
+ # next, we manually override any options that are being
+ # set via ssh_args or due to the fact that `ssh -G` doesn't
+ # actually use the options set via -o
+ for opt in ['ssh_args', 'ssh_common_args', 'ssh_extra_args']:
+ attr = getattr(play_context, opt, None)
+ if attr is not None:
+ args = split_args(attr)
+ for arg in args:
+ if '=' in arg:
+ (key, val) = arg.split('=', 1)
+ opts_dict[key.lower()] = val
+
+ return opts_dict
+ except subprocess.CalledProcessError:
+ return dict()
+
+def host_in_known_hosts(host, ssh_opts):
+ # the setting from the ssh_opts may actually be multiple files, so
+ # we use shlex.split and simply take the first one specified
+ user_host_file = os.path.expanduser(shlex.split(ssh_opts.get('userknownhostsfile', '~/.ssh/known_hosts'))[0])
+
+ host_file_list = []
+ host_file_list.append(user_host_file)
+ host_file_list.append("/etc/ssh/ssh_known_hosts")
+ host_file_list.append("/etc/ssh/ssh_known_hosts2")
+
+ hfiles_not_found = 0
+ for hf in host_file_list:
+ if not os.path.exists(hf):
+ continue
+ try:
+ host_fh = open(hf)
+ except (OSError, IOError) as e:
+ continue
+ else:
+ data = host_fh.read()
+ host_fh.close()
+
+ for line in data.split("\n"):
+ line = line.strip()
+ if line is None or " " not in line:
+ continue
+ tokens = line.split()
+ if not tokens:
+ continue
+ if tokens[0].find(HASHED_KEY_MAGIC) == 0:
+ # this is a hashed known host entry
+ try:
+ (kn_salt, kn_host) = tokens[0][len(HASHED_KEY_MAGIC):].split("|",2)
+ hash = hmac.new(kn_salt.decode('base64'), digestmod=sha1)
+ hash.update(host)
+ if hash.digest() == kn_host.decode('base64'):
+ return True
+ except:
+ # invalid hashed host key, skip it
+ continue
+ else:
+ # standard host file entry
+ if host in tokens[0]:
+ return True
+
+ return False
+
+def fetch_ssh_host_key(play_context, ssh_opts):
+ keyscan_cmd = ['ssh-keyscan']
+
+ if play_context.port:
+ keyscan_cmd.extend(['-p', text_type(play_context.port)])
+
+ if boolean(ssh_opts.get('hashknownhosts', 'no')):
+ keyscan_cmd.append('-H')
+ keyscan_cmd.append(play_context.remote_addr)
+
+ p = subprocess.Popen(keyscan_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, close_fds=True)
+ (stdout, stderr) = p.communicate()
+ if stdout == '':
+ raise AnsibleConnectionFailure("Failed to connect to the host to fetch the host key: %s." % stderr)
+ else:
+ return stdout
+
+def add_host_key(host_key, ssh_opts):
+ # the setting from the ssh_opts may actually be multiple files, so
+ # we use shlex.split and simply take the first one specified
+ user_known_hosts = os.path.expanduser(shlex.split(ssh_opts.get('userknownhostsfile', '~/.ssh/known_hosts'))[0])
+ user_ssh_dir = os.path.dirname(user_known_hosts)
+
+ if not os.path.exists(user_ssh_dir):
+ raise AnsibleError("the user ssh directory does not exist: %s" % user_ssh_dir)
+ elif not os.path.isdir(user_ssh_dir):
+ raise AnsibleError("%s is not a directory" % user_ssh_dir)
+
+ try:
+ display.vv("adding to known_hosts file: %s" % user_known_hosts)
+ with open(user_known_hosts, 'a') as f:
+ f.write(host_key)
+ except (OSError, IOError) as e:
+ raise AnsibleError("error when trying to access the known hosts file: '%s', error was: %s" % (user_known_hosts, text_type(e)))
class Connection(ConnectionBase):
''' ssh based connections '''
@@ -62,6 +192,56 @@ def __init__(self, *args, **kwargs):
def _connect(self):
return self
+ @staticmethod
+ def fetch_and_store_key(host, play_context):
+ ssh_opts = get_ssh_opts(play_context)
+ if not host_in_known_hosts(play_context.remote_addr, ssh_opts):
+ display.debug("host %s does not have a known host key, fetching it" % host)
+
+ # build the list of valid host key types, for use later as we scan for keys.
+ # we also use this to determine the most preferred key when multiple keys are available
+ valid_host_key_types = [x.lower() for x in ssh_opts.get('hostbasedkeytypes', '').split(',')]
+
+ # attempt to fetch the key with ssh-keyscan. More than one key may be
+ # returned, so we save all and use the above list to determine which
+ host_key_data = fetch_ssh_host_key(play_context, ssh_opts).strip().split('\n')
+ host_keys = dict()
+ for host_key in host_key_data:
+ (host_info, key_type, key_hash) = host_key.strip().split(' ', 3)
+ key_type = key_type.lower()
+ if key_type in valid_host_key_types and key_type not in host_keys:
+ host_keys[key_type.lower()] = host_key
+
+ if len(host_keys) == 0:
+ raise AnsibleConnectionFailure("none of the available host keys found were in the HostBasedKeyTypes configuration option")
+
+ # now we determine the preferred key by sorting the above dict on the
+ # index of the key type in the valid keys list
+ preferred_key = sorted(host_keys.items(), cmp=lambda x,y: cmp(valid_host_key_types.index(x), valid_host_key_types.index(y)), key=operator.itemgetter(0))[0]
+
+ # shamelessly copied from here:
+ # https://github.com/ojarva/python-sshpubkeys/blob/master/sshpubkeys/__init__.py#L39
+ # (which shamelessly copied it from somewhere else...)
+ (host_info, key_type, key_hash) = preferred_key[1].strip().split(' ', 3)
+ decoded_key = key_hash.decode('base64')
+ fp_plain = md5(decoded_key).hexdigest()
+ key_data = ':'.join(a+b for a, b in zip(fp_plain[::2], fp_plain[1::2]))
+
+ # prompt the user to add the key
+ # if yes, add it, otherwise raise AnsibleConnectionFailure
+ display.display("\nThe authenticity of host %s (%s) can't be established." % (host.name, play_context.remote_addr))
+ display.display("%s key fingerprint is SHA256:%s." % (key_type.upper(), sha256(decoded_key).digest().encode('base64').strip()))
+ display.display("%s key fingerprint is MD5:%s." % (key_type.upper(), key_data))
+ response = display.prompt("Are you sure you want to continue connecting (yes/no)? ")
+ display.display("")
+ if boolean(response):
+ add_host_key(host_key, ssh_opts)
+ return True
+ else:
+ raise AnsibleConnectionFailure("Host key validation failed.")
+
+ return False
+
@staticmethod
def _sshpass_available():
global SSHPASS_AVAILABLE
@@ -100,15 +280,6 @@ def _persistence_controls(command):
return controlpersist, controlpath
- @staticmethod
- def _split_args(argstring):
- """
- Takes a string like '-o Foo=1 -o Bar="foo bar"' and returns a
- list ['-o', 'Foo=1', '-o', 'Bar=foo bar'] that can be added to
- the argument list. The list will not contain any empty elements.
- """
- return [to_unicode(x.strip()) for x in shlex.split(to_bytes(argstring)) if x.strip()]
-
def _add_args(self, explanation, args):
"""
Adds the given args to self._command and displays a caller-supplied
@@ -157,7 +328,7 @@ def _build_command(self, binary, *other_args):
# Next, we add [ssh_connection]ssh_args from ansible.cfg.
if self._play_context.ssh_args:
- args = self._split_args(self._play_context.ssh_args)
+ args = split_args(self._play_context.ssh_args)
self._add_args("ansible.cfg set ssh_args", args)
# Now we add various arguments controlled by configuration file settings
@@ -210,7 +381,7 @@ def _build_command(self, binary, *other_args):
for opt in ['ssh_common_args', binary + '_extra_args']:
attr = getattr(self._play_context, opt, None)
if attr is not None:
- args = self._split_args(attr)
+ args = split_args(attr)
self._add_args("PlayContext set %s" % opt, args)
# Check if ControlPersist is enabled and add a ControlPath if one hasn't
@@ -29,7 +29,7 @@
from jinja2.exceptions import UndefinedError
from ansible import constants as C
-from ansible.errors import AnsibleError, AnsibleParserError, AnsibleUndefinedVariable
+from ansible.errors import AnsibleError, AnsibleParserError, AnsibleUndefinedVariable, AnsibleConnectionFailure
from ansible.executor.play_iterator import PlayIterator
from ansible.executor.process.worker import WorkerProcess
from ansible.executor.task_result import TaskResult
@@ -39,6 +39,7 @@
from ansible.playbook.included_file import IncludedFile
from ansible.plugins import action_loader, connection_loader, filter_loader, lookup_loader, module_loader, test_loader
from ansible.template import Templar
+from ansible.utils.connection import get_smart_connection_type
from ansible.vars.unsafe_proxy import wrap_var
try:
@@ -139,6 +140,33 @@ def _queue_task(self, host, task, task_vars, play_context):
display.debug("entering _queue_task() for %s/%s" % (host, task))
+ if C.HOST_KEY_CHECKING and not host.has_hostkey:
+ # caveat here, regarding with loops. It is assumed that none of the connection
+ # related variables would contain '{{item}}' as it would cause some really
+ # weird loops. As is, if someone did something odd like that they would need
+ # to disable host key checking
+ templar = Templar(loader=self._loader, variables=task_vars)
+ temp_pc = play_context.set_task_and_variable_override(task=task, variables=task_vars, templar=templar)
+ temp_pc.post_validate(templar)
+ if temp_pc.connection in ('smart', 'ssh') and get_smart_connection_type(temp_pc) == 'ssh':
+ try:
+ # get the ssh connection plugin's class, and use its builtin
+ # static method to fetch and save the key to the known_hosts file
+ ssh_conn = connection_loader.get('ssh', class_only=True)
+ ssh_conn.fetch_and_store_key(host, temp_pc)
+ except AnsibleConnectionFailure as e:
+ # if that fails, add the host to the list of unreachable
+ # hosts and send the appropriate callback
+ self._tqm._unreachable_hosts[host.name] = True
+ self._tqm._stats.increment('dark', host.name)
+ tr = TaskResult(host=host, task=task, return_data=dict(msg=text_type(e)))
+ self._tqm.send_callback('v2_runner_on_unreachable', tr)
+ return
+
+ # finally, we set the has_hostkey flag to true for this
+ # host so we can skip it quickly in the future
+ host.has_hostkey = True
+
task_vars['hostvars'] = self._tqm.hostvars
# and then queue the new task
display.debug("%s - putting task (%s) in queue" % (host, task))
Oops, something went wrong.

0 comments on commit e5c2c03

Please sign in to comment.