Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add 'pause' action plugin and support plugins skipping the host loop. #1089

Merged
merged 1 commit into from
Sep 26, 2012
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
21 changes: 20 additions & 1 deletion lib/ansible/runner/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,7 +563,26 @@ def run(self):

hosts = [ (self,x) for x in hosts ]
results = None
if self.forks > 1:

# Check if this is an action plugin. Some of them are designed
# to be ran once per group of hosts. Example module: pause,
# run once per hostgroup, rather than pausing once per each
# host.
p = self.action_plugins.get(self.module_name, None)
if p and getattr(p, 'BYPASS_HOST_LOOP', None):
# Expose the current hostgroup to the bypassing plugins
self.host_set = hosts
# We aren't iterating over all the hosts in this
# group. So, just pick the first host in our group to
# construct the conn object with.
result_data = self._executor(hosts[0][1]).result
# Create a ResultData item for each host in this group
# using the returned result. If we didn't do this we would
# get false reports of dark hosts.
results = [ ReturnData(host=h[1], result=result_data, comm_ok=True) \
for h in hosts ]
del self.host_set
elif self.forks > 1:
results = self._parallel_exec(hosts)
else:
results = [ self._executor(h[1]) for h in hosts ]
Expand Down
131 changes: 131 additions & 0 deletions lib/ansible/runner/action_plugins/pause.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
# Copyright 2012, Tim Bielawa <tbielawa@redhat.com>
#
# This file is part of Ansible
#
# Ansible is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Ansible is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.

from ansible.callbacks import vv
from ansible.errors import AnsibleError as ae
from ansible.runner.return_data import ReturnData
from ansible.utils import getch
from termios import tcflush, TCIFLUSH
import datetime
import sys
import time


class ActionModule(object):
''' pauses execution for a length or time, or until input is received '''

PAUSE_TYPES = ['seconds', 'minutes', 'prompt', '']
BYPASS_HOST_LOOP = True

def __init__(self, runner):
self.runner = runner
# Set defaults
self.duration_unit = 'minutes'
self.prompt = None
self.seconds = None
self.result = {'changed': False,
'rc': 0,
'stderr': '',
'stdout': '',
'start': None,
'stop': None,
'delta': None,
}

def run(self, conn, tmp, module_name, module_args, inject):
''' run the pause actionmodule '''
args = self.runner.module_args
hosts = ', '.join(map(lambda x: x[1], self.runner.host_set))

(self.pause_type, sep, pause_params) = args.partition('=')

if self.pause_type == '':
self.pause_type = 'prompt'
elif not self.pause_type in self.PAUSE_TYPES:
raise ae("invalid parameter for pause, '%s'. must be one of: %s" % \
(self.pause_type, ", ".join(self.PAUSE_TYPES)))

# error checking
if self.pause_type in ['minutes', 'seconds']:
try:
int(pause_params)
except ValueError:
raise ae("value given to %s parameter invalid: '%s', must be an integer" % \
self.pause_type, pause_params)

# The time() command operates in seconds so we need to
# recalculate for minutes=X values.
if self.pause_type == 'minutes':
self.seconds = int(pause_params) * 60
elif self.pause_type == 'seconds':
self.seconds = int(pause_params)
self.duration_unit = 'seconds'
else:
# if no args are given we pause with a prompt
if pause_params == '':
self.prompt = "[%s]\nPress enter to continue: " % hosts
else:
self.prompt = "[%s]\n%s: " % (hosts, pause_params)

vv("created 'pause' ActionModule: pause_type=%s, duration_unit=%s, calculated_seconds=%s, prompt=%s" % \
(self.pause_type, self.duration_unit, self.seconds, self.prompt))

try:
self._start()
if not self.pause_type == 'prompt':
print "[%s]\nPausing for %s seconds" % (hosts, self.seconds)
time.sleep(self.seconds)
else:
# Clear out any unflushed buffered input which would
# otherwise be consumed by raw_input() prematurely.
tcflush(sys.stdin, TCIFLUSH)
raw_input(self.prompt)
except KeyboardInterrupt:
while True:
print '\nAction? (a)bort/(c)ontinue: '
c = getch()
if c == 'c':
# continue playbook evaluation
break
elif c == 'a':
# abort further playbook evaluation
raise ae('user requested abort!')
finally:
self._stop()

return ReturnData(conn=conn, result=self.result)

def _start(self):
''' mark the time of execution for duration calculations later '''
self.start = time.time()
self.result['start'] = str(datetime.datetime.now())
if not self.pause_type == 'prompt':
print "(^C-c = continue early, ^C-a = abort)"

def _stop(self):
''' calculate the duration we actually paused for and then
finish building the task result string '''
duration = time.time() - self.start
self.result['stop'] = str(datetime.datetime.now())
self.result['delta'] = int(duration)

if self.duration_unit == 'minutes':
duration = round(duration / 60.0, 2)
else:
duration = round(duration, 2)

self.result['stdout'] = "Paused for %s %s" % (duration, self.duration_unit)
31 changes: 21 additions & 10 deletions lib/ansible/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
import glob
import subprocess
import stat
import termios
import tty

VERBOSITY=0

Expand Down Expand Up @@ -375,7 +377,7 @@ def _gitinfo():
''' returns a string containing git branch, commit id and commit date '''
result = None
repo_path = os.path.join(os.path.dirname(__file__), '..', '..', '.git')

if os.path.exists(repo_path):
# Check if the .git is a file. If it is a file, it means that we are in a submodule structure.
if os.path.isfile(repo_path):
Expand All @@ -397,7 +399,7 @@ def _gitinfo():
commit = f.readline()[:10]
f.close()
date = time.localtime(os.stat(branch_path).st_mtime)
if time.daylight == 0:
if time.daylight == 0:
offset = time.timezone
else:
offset = time.altzone
Expand All @@ -414,6 +416,17 @@ def version(prog):
result = result + " {0}".format(gitinfo)
return result

def getch():
''' read in a single character '''
fd = sys.stdin.fileno()
old_settings = termios.tcgetattr(fd)
try:
tty.setraw(sys.stdin.fileno())
ch = sys.stdin.read(1)
finally:
termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
return ch

####################################################################
# option handling code for /usr/bin/ansible and ansible-playbook
# below this line
Expand All @@ -429,7 +442,7 @@ def increment_debug(option, opt, value, parser):
global VERBOSITY
VERBOSITY += 1

def base_parser(constants=C, usage="", output_opts=False, runas_opts=False,
def base_parser(constants=C, usage="", output_opts=False, runas_opts=False,
async_opts=False, connect_opts=False, subset_opts=False):
''' create an options parser for any ansible script '''

Expand Down Expand Up @@ -515,15 +528,15 @@ def last_non_blank_line(buf):
if (len(line) > 0):
return line
# shouldn't occur unless there's no output
return ""
return ""

def filter_leading_non_json_lines(buf):
'''
'''
used to avoid random output from SSH at the top of JSON output, like messages from
tcagetattr, or where dropbear spews MOTD on every single command (which is nuts).

need to filter anything which starts not with '{', '[', ', '=' or is an empty line.
filter only leading lines since multiline JSON is valid.
filter only leading lines since multiline JSON is valid.
'''

filtered_lines = StringIO.StringIO()
Expand All @@ -536,12 +549,10 @@ def filter_leading_non_json_lines(buf):

def import_plugins(directory):
modules = {}
for path in glob.glob(os.path.join(directory, '*.py')):
for path in glob.glob(os.path.join(directory, '*.py')):
if path.startswith("_"):
continue
name, ext = os.path.splitext(os.path.basename(path))
if not name.startswith("_"):
modules[name] = imp.load_source(name, path)
return modules