diff --git a/bin/ansible-playbook b/bin/ansible-playbook index 646b64a764f6c0..2d3d3480e00dec 100755 --- a/bin/ansible-playbook +++ b/bin/ansible-playbook @@ -62,6 +62,8 @@ def main(args): check_opts=True, diff_opts=True ) + #parser.add_option('--vault-password', dest="vault_password", + # help="password for vault encrypted files") parser.add_option('-e', '--extra-vars', dest="extra_vars", action="append", help="set additional variables as key=value or YAML/JSON", default=[]) parser.add_option('-t', '--tags', dest='tags', default='all', @@ -100,12 +102,13 @@ def main(args): su_pass = None if not options.listhosts and not options.syntax and not options.listtasks: options.ask_pass = options.ask_pass or C.DEFAULT_ASK_PASS + options.ask_vault_pass = options.ask_vault_pass or C.DEFAULT_ASK_VAULT_PASS # Never ask for an SSH password when we run with local connection if options.connection == "local": options.ask_pass = False options.ask_sudo_pass = options.ask_sudo_pass or C.DEFAULT_ASK_SUDO_PASS options.ask_su_pass = options.ask_su_pass or C.DEFAULT_ASK_SU_PASS - (sshpass, sudopass, su_pass) = utils.ask_passwords(ask_pass=options.ask_pass, ask_sudo_pass=options.ask_sudo_pass, ask_su_pass=options.ask_su_pass) + (sshpass, sudopass, su_pass, vault_pass) = utils.ask_passwords(ask_pass=options.ask_pass, ask_sudo_pass=options.ask_sudo_pass, ask_su_pass=options.ask_su_pass, ask_vault_pass=options.ask_vault_pass) options.sudo_user = options.sudo_user or C.DEFAULT_SUDO_USER options.su_user = options.su_user or C.DEFAULT_SU_USER @@ -170,7 +173,8 @@ def main(args): diff=options.diff, su=options.su, su_pass=su_pass, - su_user=options.su_user + su_user=options.su_user, + vault_password=vault_pass ) if options.listhosts or options.listtasks or options.syntax: diff --git a/bin/ansible-vault b/bin/ansible-vault new file mode 100755 index 00000000000000..9fd19bda2e982f --- /dev/null +++ b/bin/ansible-vault @@ -0,0 +1,187 @@ +#!/usr/bin/env python + +# (c) 2014, James Tanner +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see . +# +# ansible-pull is a script that runs ansible in local mode +# after checking out a playbooks directory from source repo. There is an +# example playbook to bootstrap this script in the examples/ dir which +# installs ansible and sets it up to run on cron. + +import sys +import traceback + +from ansible import utils +from ansible import errors +from ansible.utils.vault import * +from ansible.utils.vault import Vault + +from optparse import OptionParser + +#------------------------------------------------------------------------------------- +# Utility functions for parsing actions/options +#------------------------------------------------------------------------------------- + +VALID_ACTIONS = ("create", "decrypt", "edit", "encrypt", "rekey") + +def build_option_parser(action): + """ + Builds an option parser object based on the action + the user wants to execute. + """ + + usage = "usage: %%prog [%s] [--help] [options] file_name" % "|".join(VALID_ACTIONS) + epilog = "\nSee '%s --help' for more information on a specific command.\n\n" % os.path.basename(sys.argv[0]) + OptionParser.format_epilog = lambda self, formatter: self.epilog + parser = OptionParser(usage=usage, epilog=epilog) + + if not action: + parser.print_help() + sys.exit() + + # options for all actions + #parser.add_option('-p', '--password', help="encryption key") + #parser.add_option('-c', '--cipher', dest='cipher', default="AES", help="cipher to use") + parser.add_option('-d', '--debug', dest='debug', action="store_true", help="debug") + + # options specific to actions + if action == "create": + parser.set_usage("usage: %prog create [options] file_name") + elif action == "decrypt": + parser.set_usage("usage: %prog decrypt [options] file_name") + elif action == "edit": + parser.set_usage("usage: %prog edit [options] file_name") + elif action == "encrypt": + parser.set_usage("usage: %prog encrypt [options] file_name") + elif action == "rekey": + parser.set_usage("usage: %prog rekey [options] file_name") + + # done, return the parser + return parser + +def get_action(args): + """ + Get the action the user wants to execute from the + sys argv list. + """ + for i in range(0,len(args)): + arg = args[i] + if arg in VALID_ACTIONS: + del args[i] + return arg + return None + +def get_opt(options, k, defval=""): + """ + Returns an option from an Optparse values instance. + """ + try: + data = getattr(options, k) + except: + return defval + if k == "roles_path": + if os.pathsep in data: + data = data.split(os.pathsep)[0] + return data + +#------------------------------------------------------------------------------------- +# Command functions +#------------------------------------------------------------------------------------- + +def _get_vault(filename, options, password): + this_vault = Vault() + this_vault.filename = filename + this_vault.vault_password = password + this_vault.password = password + return this_vault + +def execute_create(args, options, parser): + + if len(args) > 1: + raise errors.AnsibleError("create does not accept more than one filename") + + password, new_password = utils.ask_vaultpasswords(ask_vault_pass=True, confirm_vault=True) + + this_vault = _get_vault(args[0], options, password) + if not hasattr(options, 'cipher'): + this_vault.cipher = 'AES' + this_vault.create() + +def execute_decrypt(args, options, parser): + + password, new_password = utils.ask_vaultpasswords(ask_vault_pass=True) + + for f in args: + this_vault = _get_vault(f, options, password) + this_vault.decrypt() + + print "Decryption successful" + +def execute_edit(args, options, parser): + + if len(args) > 1: + raise errors.AnsibleError("create does not accept more than one filename") + + password, new_password = utils.ask_vaultpasswords(ask_vault_pass=True) + + for f in args: + this_vault = _get_vault(f, options, password) + this_vault.edit() + +def execute_encrypt(args, options, parser): + + password, new_password = utils.ask_vaultpasswords(ask_vault_pass=True, confirm_vault=True) + + for f in args: + this_vault = _get_vault(f, options, password) + if not hasattr(options, 'cipher'): + this_vault.cipher = 'AES' + this_vault.encrypt() + + print "Encryption successful" + +def execute_rekey(args, options, parser): + + password, new_password = utils.ask_vaultpasswords(ask_vault_pass=True, ask_new_vault_pass=True, confirm_new=True) + + for f in args: + this_vault = _get_vault(f, options, password) + this_vault.rekey(new_password) + + print "Rekey successful" + +#------------------------------------------------------------------------------------- +# MAIN +#------------------------------------------------------------------------------------- + +def main(): + + action = get_action(sys.argv) + parser = build_option_parser(action) + (options, args) = parser.parse_args() + + # execute the desired action + try: + fn = globals()["execute_%s" % action] + fn(args, options, parser) + except Exception, err: + if options.debug: + print traceback.format_exc() + print "ERROR:",err + sys.exit(1) + +if __name__ == "__main__": + main() + diff --git a/lib/ansible/constants.py b/lib/ansible/constants.py index dbe7ef2db519f5..b8dbc8e7d314f3 100644 --- a/lib/ansible/constants.py +++ b/lib/ansible/constants.py @@ -117,6 +117,7 @@ def shell_expand_path(path): DEFAULT_SUDO_USER = get_config(p, DEFAULTS, 'sudo_user', 'ANSIBLE_SUDO_USER', 'root') DEFAULT_ASK_SUDO_PASS = get_config(p, DEFAULTS, 'ask_sudo_pass', 'ANSIBLE_ASK_SUDO_PASS', False, boolean=True) DEFAULT_REMOTE_PORT = get_config(p, DEFAULTS, 'remote_port', 'ANSIBLE_REMOTE_PORT', None, integer=True) +DEFAULT_ASK_VAULT_PASS = get_config(p, DEFAULTS, 'ask_vault_pass', 'ANSIBLE_ASK_VAULT_PASS', False, boolean=True) DEFAULT_TRANSPORT = get_config(p, DEFAULTS, 'transport', 'ANSIBLE_TRANSPORT', 'smart') DEFAULT_SCP_IF_SSH = get_config(p, 'ssh_connection', 'scp_if_ssh', 'ANSIBLE_SCP_IF_SSH', False, boolean=True) DEFAULT_MANAGED_STR = get_config(p, DEFAULTS, 'ansible_managed', None, 'Ansible managed: {file} modified on %Y-%m-%d %H:%M:%S by {uid} on {host}') @@ -172,4 +173,5 @@ def shell_expand_path(path): DEFAULT_REMOTE_PASS = None DEFAULT_SUBSET = None DEFAULT_SU_PASS = None - +VAULT_VERSION_MIN = 1.0 +VAULT_VERSION_MAX = 1.0 diff --git a/lib/ansible/inventory/__init__.py b/lib/ansible/inventory/__init__.py index ed945ce50c2013..8b8bbcceef261c 100644 --- a/lib/ansible/inventory/__init__.py +++ b/lib/ansible/inventory/__init__.py @@ -347,19 +347,19 @@ def _get_group_variables(self, groupname): raise Exception("group not found: %s" % groupname) return group.get_variables() - def get_variables(self, hostname): + def get_variables(self, hostname, vault_password=None): if hostname not in self._vars_per_host: - self._vars_per_host[hostname] = self._get_variables(hostname) + self._vars_per_host[hostname] = self._get_variables(hostname, vault_password=vault_password) return self._vars_per_host[hostname] - def _get_variables(self, hostname): + def _get_variables(self, hostname, vault_password=None): host = self.get_host(hostname) if host is None: raise errors.AnsibleError("host not found: %s" % hostname) vars = {} - vars_results = [ plugin.run(host) for plugin in self._vars_plugins ] + vars_results = [ plugin.run(host, vault_password=vault_password) for plugin in self._vars_plugins ] for updated in vars_results: if updated is not None: vars.update(updated) diff --git a/lib/ansible/inventory/vars_plugins/group_vars.py b/lib/ansible/inventory/vars_plugins/group_vars.py index 06a22169a7aaf5..3421565a5fb04c 100644 --- a/lib/ansible/inventory/vars_plugins/group_vars.py +++ b/lib/ansible/inventory/vars_plugins/group_vars.py @@ -23,7 +23,7 @@ from ansible import utils import ansible.constants as C -def _load_vars(basepath, results): +def _load_vars(basepath, results, vault_password=None): """ Load variables from any potential yaml filename combinations of basepath, returning result. @@ -35,7 +35,7 @@ def _load_vars(basepath, results): found_paths = [] for path in paths_to_check: - found, results = _load_vars_from_path(path, results) + found, results = _load_vars_from_path(path, results, vault_password=vault_password) if found: found_paths.append(path) @@ -49,7 +49,7 @@ def _load_vars(basepath, results): return results -def _load_vars_from_path(path, results): +def _load_vars_from_path(path, results, vault_password=None): """ Robustly access the file at path and load variables, carefully reporting errors in a friendly/informative way. @@ -90,7 +90,7 @@ def _load_vars_from_path(path, results): # regular file elif stat.S_ISREG(pathstat.st_mode): - data = utils.parse_yaml_from_file(path) + data = utils.parse_yaml_from_file(path, vault_password=vault_password) if type(data) != dict: raise errors.AnsibleError( "%s must be stored as a dictionary/hash" % path) @@ -143,7 +143,7 @@ def __init__(self, inventory): self.inventory = inventory - def run(self, host): + def run(self, host, vault_password=None): """ main body of the plugin, does actual loading """ @@ -183,11 +183,11 @@ def run(self, host): # load vars in dir/group_vars/name_of_group for group in groups: base_path = os.path.join(basedir, "group_vars/%s" % group) - results = _load_vars(base_path, results) + results = _load_vars(base_path, results, vault_password=vault_password) # same for hostvars in dir/host_vars/name_of_host base_path = os.path.join(basedir, "host_vars/%s" % host.name) - results = _load_vars(base_path, results) + results = _load_vars(base_path, results, vault_password=vault_password) # all done, results is a dictionary of variables for this particular host. return results diff --git a/lib/ansible/playbook/__init__.py b/lib/ansible/playbook/__init__.py index 9b19c59ac9fedd..65965526251de5 100644 --- a/lib/ansible/playbook/__init__.py +++ b/lib/ansible/playbook/__init__.py @@ -72,6 +72,7 @@ def __init__(self, su = False, su_user = False, su_pass = False, + vault_password = False, ): """ @@ -138,6 +139,7 @@ def __init__(self, self.su = su self.su_user = su_user self.su_pass = su_pass + self.vault_password = vault_password self.callbacks.playbook = self self.runner_callbacks.playbook = self @@ -172,7 +174,7 @@ def _load_playbook_from_file(self, path, vars={}): run top level error checking on playbooks and allow them to include other playbooks. ''' - playbook_data = utils.parse_yaml_from_file(path) + playbook_data = utils.parse_yaml_from_file(path, vault_password=self.vault_password) accumulated_plays = [] play_basedirs = [] @@ -242,7 +244,7 @@ def run(self): # loop through all patterns and run them self.callbacks.on_start() for (play_ds, play_basedir) in zip(self.playbook, self.play_basedirs): - play = Play(self, play_ds, play_basedir) + play = Play(self, play_ds, play_basedir, vault_password=self.vault_password) assert play is not None matched_tags, unmatched_tags = play.compare_tags(self.only_tags) @@ -352,6 +354,7 @@ def _run_task_internal(self, task): su=task.su, su_user=task.su_user, su_pass=task.su_pass, + vault_pass = self.vault_password, run_hosts=hosts, no_log=task.no_log, ) @@ -504,6 +507,7 @@ def _do_setup_step(self, play): su=play.su, su_user=play.su_user, su_pass=self.su_pass, + vault_pass=self.vault_password, transport=play.transport, is_playbook=True, module_vars=play.vars, @@ -569,9 +573,8 @@ def _run_play(self, play): self._do_setup_step(play) # now with that data, handle contentional variable file imports! - all_hosts = self._trim_unavailable_hosts(play._play_hosts) - play.update_vars_files(all_hosts) + play.update_vars_files(all_hosts, vault_password=self.vault_password) hosts_count = len(all_hosts) serialized_batch = [] diff --git a/lib/ansible/playbook/play.py b/lib/ansible/playbook/play.py index 94f4dc074ae153..d431ad48e5c120 100644 --- a/lib/ansible/playbook/play.py +++ b/lib/ansible/playbook/play.py @@ -34,7 +34,7 @@ class Play(object): 'handlers', 'remote_user', 'remote_port', 'included_roles', 'accelerate', 'accelerate_port', 'accelerate_ipv6', 'sudo', 'sudo_user', 'transport', 'playbook', 'tags', 'gather_facts', 'serial', '_ds', '_handlers', '_tasks', - 'basedir', 'any_errors_fatal', 'roles', 'max_fail_pct', '_play_hosts', 'su', 'su_user' + 'basedir', 'any_errors_fatal', 'roles', 'max_fail_pct', '_play_hosts', 'su', 'su_user', 'vault_password' ] # to catch typos and so forth -- these are userland names @@ -44,12 +44,12 @@ class Play(object): 'tasks', 'handlers', 'remote_user', 'user', 'port', 'include', 'accelerate', 'accelerate_port', 'accelerate_ipv6', 'sudo', 'sudo_user', 'connection', 'tags', 'gather_facts', 'serial', 'any_errors_fatal', 'roles', 'pre_tasks', 'post_tasks', 'max_fail_percentage', - 'su', 'su_user' + 'su', 'su_user', 'vault_password' ] # ************************************************* - def __init__(self, playbook, ds, basedir): + def __init__(self, playbook, ds, basedir, vault_password=None): ''' constructor loads from a play datastructure ''' for x in ds.keys(): @@ -64,6 +64,7 @@ def __init__(self, playbook, ds, basedir): self.basedir = basedir self.roles = ds.get('roles', None) self.tags = ds.get('tags', None) + self.vault_password = vault_password if self.tags is None: self.tags = [] @@ -88,6 +89,7 @@ def __init__(self, playbook, ds, basedir): self.vars_files = ds.get('vars_files', []) if not isinstance(self.vars_files, list): raise errors.AnsibleError('vars_files must be a list') + self._update_vars_files_for_host(None) # template everything to be efficient, but do not pre-mature template @@ -124,6 +126,7 @@ def __init__(self, playbook, ds, basedir): self.max_fail_pct = int(ds.get('max_fail_percentage', 100)) self.su = ds.get('su', self.playbook.su) self.su_user = ds.get('su_user', self.playbook.su_user) + #self.vault_password = vault_password # Fail out if user specifies a sudo param with a su param in a given play if (ds.get('sudo') or ds.get('sudo_user')) and (ds.get('su') or ds.get('su_user')): @@ -540,7 +543,7 @@ def _load_tasks(self, tasks, vars=None, default_vars=None, sudo_vars=None, dirname = os.path.dirname(original_file) include_file = template(dirname, tokens[0], mv) include_filename = utils.path_dwim(dirname, include_file) - data = utils.parse_yaml_from_file(include_filename) + data = utils.parse_yaml_from_file(include_filename, vault_password=self.vault_password) if 'role_name' in x and data is not None: for x in data: if 'include' in x: @@ -652,12 +655,12 @@ def _get_vars(self): # ************************************************* - def update_vars_files(self, hosts): + def update_vars_files(self, hosts, vault_password=None): ''' calculate vars_files, which requires that setup runs first so ansible facts can be mixed in ''' # now loop through all the hosts... for h in hosts: - self._update_vars_files_for_host(h) + self._update_vars_files_for_host(h, vault_password=vault_password) # ************************************************* @@ -689,14 +692,14 @@ def _has_vars_in(self, msg): # ************************************************* - def _update_vars_files_for_host(self, host): + def _update_vars_files_for_host(self, host, vault_password=None): if type(self.vars_files) != list: self.vars_files = [ self.vars_files ] if host is not None: inject = {} - inject.update(self.playbook.inventory.get_variables(host)) + inject.update(self.playbook.inventory.get_variables(host, vault_password=vault_password)) inject.update(self.playbook.SETUP_CACHE[host]) for filename in self.vars_files: @@ -747,7 +750,7 @@ def _update_vars_files_for_host(self, host): filename4 = utils.path_dwim(self.basedir, filename3) if self._has_vars_in(filename4): continue - new_vars = utils.parse_yaml_from_file(filename4) + new_vars = utils.parse_yaml_from_file(filename4, vault_password=self.vault_password) if new_vars: if type(new_vars) != dict: raise errors.AnsibleError("%s must be stored as dictionary/hash: %s" % (filename4, type(new_vars))) diff --git a/lib/ansible/runner/__init__.py b/lib/ansible/runner/__init__.py index 189bcb3b92781c..1a8d8ca7abdfbe 100644 --- a/lib/ansible/runner/__init__.py +++ b/lib/ansible/runner/__init__.py @@ -144,6 +144,7 @@ def __init__(self, su=False, # Are we running our command via su? su_user=None, # User to su to when running command, ex: 'root' su_pass=C.DEFAULT_SU_PASS, + vault_pass=None, run_hosts=None, # an optional list of pre-calculated hosts to run on no_log=False, # option to enable/disable logging for a given task ): @@ -197,6 +198,7 @@ def __init__(self, self.su_user_var = su_user self.su_user = None self.su_pass = su_pass + self.vault_pass = vault_pass self.no_log = no_log if self.transport == 'smart': @@ -534,7 +536,7 @@ def get_flags(): def _executor_internal(self, host, new_stdin): ''' executes any module one or more times ''' - host_variables = self.inventory.get_variables(host) + host_variables = self.inventory.get_variables(host, vault_password=self.vault_pass) host_connection = host_variables.get('ansible_connection', self.transport) if host_connection in [ 'paramiko', 'paramiko_alt', 'ssh', 'ssh_old', 'accelerate' ]: port = host_variables.get('ansible_ssh_port', self.remote_port) diff --git a/lib/ansible/utils/__init__.py b/lib/ansible/utils/__init__.py index 0f3f7f36e3bd09..1fabd7f4ff9e3c 100644 --- a/lib/ansible/utils/__init__.py +++ b/lib/ansible/utils/__init__.py @@ -43,6 +43,8 @@ import sys import textwrap +import vault + VERBOSITY=0 # list of all deprecation messages to prevent duplicate display @@ -494,14 +496,22 @@ def process_yaml_error(exc, data, path=None): raise errors.AnsibleYAMLValidationFailed(msg) -def parse_yaml_from_file(path): +def parse_yaml_from_file(path, vault_password=None): ''' convert a yaml file to a data structure ''' + data = None + + #VAULT + if vault.is_encrypted(path): + data = vault.decrypt(path, vault_password) + else: + try: + data = open(path).read() + except IOError: + raise errors.AnsibleError("file could not read: %s" % path) + try: - data = file(path).read() return parse_yaml(data) - except IOError: - raise errors.AnsibleError("file not found: %s" % path) except yaml.YAMLError, exc: process_yaml_error(exc, data, path) @@ -693,6 +703,8 @@ def base_parser(constants=C, usage="", output_opts=False, runas_opts=False, help='ask for sudo password') parser.add_option('--ask-su-pass', default=False, dest='ask_su_pass', action='store_true', help='ask for su password') + parser.add_option('--ask-vault-pass', default=False, dest='ask_vault_pass', + action='store_true', help='ask for vault password') parser.add_option('--list-hosts', dest='listhosts', action='store_true', help='outputs a list of matching hosts; does not execute anything else') parser.add_option('-M', '--module-path', dest='module_path', @@ -751,10 +763,34 @@ def base_parser(constants=C, usage="", output_opts=False, runas_opts=False, return parser -def ask_passwords(ask_pass=False, ask_sudo_pass=False, ask_su_pass=False): +def ask_vaultpasswords(ask_vault_pass=False, ask_new_vault_pass=False, confirm_vault=False, confirm_new=False): + + vault_pass = None + new_vault_pass = None + + if ask_vault_pass: + vault_pass = getpass.getpass(prompt="Vault password: ") + + if ask_vault_pass and confirm_vault: + vault_pass2 = getpass.getpass(prompt="Retype Vault password: ") + if vault_pass != vault_pass2: + raise errors.AnsibleError("Passwords do not match") + + if ask_new_vault_pass: + new_vault_pass = getpass.getpass(prompt="New Vault password: ") + + if ask_new_vault_pass and confirm_new: + new_vault_pass2 = getpass.getpass(prompt="Retype New Vault password: ") + if new_vault_pass != new_vault_pass2: + raise errors.AnsibleError("Passwords do not match") + + return vault_pass, new_vault_pass + +def ask_passwords(ask_pass=False, ask_sudo_pass=False, ask_su_pass=False, ask_vault_pass=False): sshpass = None sudopass = None su_pass = None + vault_pass = None sudo_prompt = "sudo password: " su_prompt = "su password: " @@ -770,7 +806,10 @@ def ask_passwords(ask_pass=False, ask_sudo_pass=False, ask_su_pass=False): if ask_su_pass: su_pass = getpass.getpass(prompt=su_prompt) - return (sshpass, sudopass, su_pass) + if ask_vault_pass: + vault_pass = getpass.getpass(prompt="Vault password: ") + + return (sshpass, sudopass, su_pass, vault_pass) def do_encrypt(result, encrypt, salt_size=None, salt=None): if PASSLIB_AVAILABLE: diff --git a/lib/ansible/utils/vault.py b/lib/ansible/utils/vault.py new file mode 100644 index 00000000000000..b41d0c39886158 --- /dev/null +++ b/lib/ansible/utils/vault.py @@ -0,0 +1,450 @@ +# (c) 2014, James Tanner +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see . +# +# ansible-pull is a script that runs ansible in local mode +# after checking out a playbooks directory from source repo. There is an +# example playbook to bootstrap this script in the examples/ dir which +# installs ansible and sets it up to run on cron. + +import os +import shutil +import tempfile +from io import BytesIO +from subprocess import call +from ansible import errors +from hashlib import sha256 +from hashlib import md5 +from binascii import hexlify +from binascii import unhexlify +from ansible import constants as C + +# AES IMPORTS +try: + from Crypto.Cipher import AES as AES_ + HAS_AES = True +except ImportError: + HAS_AES = False + +HEADER='$ANSIBLE_VAULT' + +def is_encrypted(filename): + + ''' + Check a file for the encrypted header and return True or False + + The first line should start with the header + defined by the global HEADER. If true, we + assume this is a properly encrypted file. + ''' + + # read first line of the file + with open(filename) as f: + head = f.next() + + if head.startswith(HEADER): + return True + else: + return False + +def decrypt(filename, password): + + ''' + Return a decrypted string of the contents in an encrypted file + + This is used by the yaml loading code in ansible + to automatically determine the encryption type + and return a plaintext string of the unencrypted + data. + ''' + + if password is None: + raise errors.AnsibleError("A vault password must be specified to decrypt %s" % filename) + + V = Vault(filename=filename, vault_password=password) + return_data = V._decrypt_to_string() + + if not V._verify_decryption(return_data): + raise errors.AnsibleError("Decryption of %s failed" % filename) + + this_sha, return_data = V._strip_sha(return_data) + return return_data.strip() + + +class Vault(object): + def __init__(self, filename=None, cipher=None, vault_password=None): + self.filename = filename + self.vault_password = vault_password + self.cipher = cipher + self.version = '1.0' + + ############### + # PUBLIC + ############### + + def eval_header(self): + + """ Read first line of the file and parse header """ + + # read first line + with open(self.filename) as f: + #head=[f.next() for x in xrange(1)] + head = f.next() + + this_version = None + this_cipher = None + + # split segments + if len(head.split(';')) == 3: + this_version = head.split(';')[1].strip() + this_cipher = head.split(';')[2].strip() + else: + raise errors.AnsibleError("%s has an invalid header" % self.filename) + + # validate acceptable version + this_version = float(this_version) + if this_version < C.VAULT_VERSION_MIN or this_version > C.VAULT_VERSION_MAX: + raise errors.AnsibleError("%s must have a version between %s and %s " % (self.filename, + C.VAULT_VERSION_MIN, + C.VAULT_VERSION_MAX)) + # set properties + self.cipher = this_cipher + self.version = this_version + + def create(self): + """ create a new encrypted file """ + + if os.path.isfile(self.filename): + raise errors.AnsibleError("%s exists, please use 'edit' instead" % self.filename) + + # drop the user into vim on file + EDITOR = os.environ.get('EDITOR','vim') + call([EDITOR, self.filename]) + + self.encrypt() + + def decrypt(self): + """ unencrypt a file inplace """ + + if not is_encrypted(self.filename): + raise errors.AnsibleError("%s is not encrypted" % self.filename) + + # set cipher based on file header + self.eval_header() + + # decrypt it + data = self._decrypt_to_string() + + # verify sha and then strip it out + if not self._verify_decryption(data): + raise errors.AnsibleError("decryption of %s failed" % self.filename) + this_sha, clean_data = self._strip_sha(data) + + # write back to original file + f = open(self.filename, "wb") + f.write(clean_data) + f.close() + + def edit(self, filename=None, password=None, cipher=None, version=None): + + if not is_encrypted(self.filename): + raise errors.AnsibleError("%s is not encrypted" % self.filename) + + #decrypt to string + data = self._decrypt_to_string() + + # verify sha and then strip it out + if not self._verify_decryption(data): + raise errors.AnsibleError("decryption of %s failed" % self.filename) + this_sha, clean_data = self._strip_sha(data) + + # rewrite file without sha + _, in_path = tempfile.mkstemp() + f = open(in_path, "wb") + tmpdata = f.write(clean_data) + f.close() + + # drop the user into vim on the unencrypted tmp file + EDITOR = os.environ.get('EDITOR','vim') + call([EDITOR, in_path]) + + f = open(in_path, "rb") + tmpdata = f.read() + f.close() + + self._string_to_encrypted_file(tmpdata, self.filename) + + + def encrypt(self): + """ encrypt a file inplace """ + + if is_encrypted(self.filename): + raise errors.AnsibleError("%s is already encrypted" % self.filename) + + #self.eval_header() + self.__load_cipher() + + # read data + f = open(self.filename, "rb") + tmpdata = f.read() + f.close() + + self._string_to_encrypted_file(tmpdata, self.filename) + + + def rekey(self, newpassword): + + """ unencrypt file then encrypt with new password """ + + if not is_encrypted(self.filename): + raise errors.AnsibleError("%s is not encrypted" % self.filename) + + # unencrypt to string with old password + data = self._decrypt_to_string() + + # verify sha and then strip it out + if not self._verify_decryption(data): + raise errors.AnsibleError("decryption of %s failed" % self.filename) + this_sha, clean_data = self._strip_sha(data) + + # set password + self.vault_password = newpassword + + self._string_to_encrypted_file(clean_data, self.filename) + + + ############### + # PRIVATE + ############### + + def __load_cipher(self): + + """ + Load a cipher class by it's name + + This is a lightweight "plugin" implementation to allow + for future support of other cipher types + """ + + whitelist = ['AES'] + + if self.cipher in whitelist: + self.cipher_obj = None + if self.cipher in globals(): + this_cipher = globals()[self.cipher] + self.cipher_obj = this_cipher() + else: + raise errors.AnsibleError("%s cipher could not be loaded" % self.cipher) + else: + raise errors.AnsibleError("%s is not an allowed encryption cipher" % self.cipher) + + + + def _decrypt_to_string(self): + + """ decrypt file to string """ + + if not is_encrypted(self.filename): + raise errors.AnsibleError("%s is not encrypted" % self.filename) + + # figure out what this is + self.eval_header() + self.__load_cipher() + + # strip out header and unhex the file + clean_stream = self._dirty_file_to_clean_file(self.filename) + + # reset pointer + clean_stream.seek(0) + + # create a byte stream to hold unencrypted + dst = BytesIO() + + # decrypt from src stream to dst stream + self.cipher_obj.decrypt(clean_stream, dst, self.vault_password) + + # read data from the unencrypted stream + data = dst.read() + + return data + + def _dirty_file_to_clean_file(self, dirty_filename): + """ Strip out headers from a file, unhex and write to new file""" + + + _, in_path = tempfile.mkstemp() + #_, out_path = tempfile.mkstemp() + + # strip header from data, write rest to tmp file + f = open(dirty_filename, "rb") + tmpdata = f.readlines() + f.close() + + tmpheader = tmpdata[0].strip() + tmpdata = ''.join(tmpdata[1:]) + + # strip out newline, join, unhex + tmpdata = [ x.strip() for x in tmpdata ] + tmpdata = unhexlify(''.join(tmpdata)) + + # create and return stream + clean_stream = BytesIO(tmpdata) + return clean_stream + + def _clean_stream_to_dirty_stream(self, clean_stream): + + # combine header and hexlified encrypted data in 80 char columns + clean_stream.seek(0) + tmpdata = clean_stream.read() + tmpdata = hexlify(tmpdata) + tmpdata = [tmpdata[i:i+80] for i in range(0, len(tmpdata), 80)] + + dirty_data = HEADER + ";" + str(self.version) + ";" + self.cipher + "\n" + for l in tmpdata: + dirty_data += l + '\n' + + dirty_stream = BytesIO(dirty_data) + return dirty_stream + + def _string_to_encrypted_file(self, tmpdata, filename): + + """ Write a string of data to a file with the format ... + + HEADER;VERSION;CIPHER + HEX(ENCRYPTED(SHA256(STRING)+STRING)) + """ + + # sha256 the data + this_sha = sha256(tmpdata).hexdigest() + + # combine sha + data to tmpfile + tmpdata = this_sha + "\n" + tmpdata + src_stream = BytesIO(tmpdata) + dst_stream = BytesIO() + + # encrypt tmpfile + self.cipher_obj.encrypt(src_stream, dst_stream, self.password) + + # hexlify tmpfile and combine with header + dirty_stream = self._clean_stream_to_dirty_stream(dst_stream) + + if os.path.isfile(filename): + os.remove(filename) + + # write back to original file + dirty_stream.seek(0) + f = open(filename, "wb") + f.write(dirty_stream.read()) + f.close() + + + def _verify_decryption(self, data): + + """ Split data to sha/data and check the sha """ + + # split the sha and other data + this_sha, clean_data = self._strip_sha(data) + + # does the decrypted data match the sha ? + clean_sha = sha256(clean_data).hexdigest() + + # compare, return result + if this_sha == clean_sha: + return True + else: + return False + + def _strip_sha(self, data): + # is the first line a sha? + lines = data.split("\n") + this_sha = lines[0] + + clean_data = '\n'.join(lines[1:]) + return this_sha, clean_data + + +class AES(object): + + # http://stackoverflow.com/a/16761459 + + def __init__(self): + if not HAS_AES: + raise errors.AnsibleError("pycrypto is not installed. Fix this with your package manager, for instance, yum-install python-crypto OR (apt equivalent)") + + def aes_derive_key_and_iv(self, password, salt, key_length, iv_length): + + """ Create a key and an initialization vector """ + + d = d_i = '' + while len(d) < key_length + iv_length: + d_i = md5(d_i + password + salt).digest() + d += d_i + + key = d[:key_length] + iv = d[key_length:key_length+iv_length] + + return key, iv + + def encrypt(self, in_file, out_file, password, key_length=32): + + """ Read plaintext data from in_file and write encrypted to out_file """ + + bs = AES_.block_size + + # Get a block of random data. EL does not have Crypto.Random.new() + # so os.urandom is used for cross platform purposes + print "WARNING: if encryption hangs, add more entropy (suggest using mouse inputs)" + salt = os.urandom(bs - len('Salted__')) + + key, iv = self.aes_derive_key_and_iv(password, salt, key_length, bs) + cipher = AES_.new(key, AES_.MODE_CBC, iv) + out_file.write('Salted__' + salt) + finished = False + while not finished: + chunk = in_file.read(1024 * bs) + if len(chunk) == 0 or len(chunk) % bs != 0: + padding_length = (bs - len(chunk) % bs) or bs + chunk += padding_length * chr(padding_length) + finished = True + out_file.write(cipher.encrypt(chunk)) + + def decrypt(self, in_file, out_file, password, key_length=32): + + """ Read encrypted data from in_file and write decrypted to out_file """ + + # http://stackoverflow.com/a/14989032 + + bs = AES_.block_size + salt = in_file.read(bs)[len('Salted__'):] + key, iv = self.aes_derive_key_and_iv(password, salt, key_length, bs) + cipher = AES_.new(key, AES_.MODE_CBC, iv) + next_chunk = '' + finished = False + + out_data = '' + + while not finished: + chunk, next_chunk = next_chunk, cipher.decrypt(in_file.read(1024 * bs)) + if len(next_chunk) == 0: + padding_length = ord(chunk[-1]) + chunk = chunk[:-padding_length] + finished = True + out_data += chunk + + # write decrypted data to out stream + out_file.write(out_data) + + # reset the stream pointer to the beginning + if hasattr(out_file, 'seek'): + out_file.seek(0)