Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions awsshell/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,11 +149,12 @@ def run(self, command, application):


class WizardHandler(object):
def __init__(self, output=sys.stdout, err=sys.stderr,
loader=WizardLoader()):
def __init__(self, output=sys.stdout, err=sys.stderr, loader=None):
self._output = output
self._err = err
self._wizard_loader = loader
if self._wizard_loader is None:
self._wizard_loader = WizardLoader()

def run(self, command, application):
"""Run the specified wizard.
Expand Down
50 changes: 31 additions & 19 deletions awsshell/wizard.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,20 +111,21 @@ def create_wizard(self, model):
stages = self._load_stages(model.get('Stages'), env)
return Wizard(start_stage, stages, env, self._error_handler)

def _load_stage(self, stage, env):
stage_attrs = {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm glad that you removed that. Usually it is not great practice to define a function inside a function, especially if it will be called more than once.

'name': stage.get('Name'),
'prompt': stage.get('Prompt'),
'retrieval': stage.get('Retrieval'),
'next_stage': stage.get('NextStage'),
'resolution': stage.get('Resolution'),
'interaction': stage.get('Interaction'),
}
creator = self._cached_creator
interaction = self._interaction_loader
return Stage(env, creator, interaction, self, **stage_attrs)

def _load_stages(self, stages, env):
def load_stage(stage):
stage_attrs = {
'name': stage.get('Name'),
'prompt': stage.get('Prompt'),
'retrieval': stage.get('Retrieval'),
'next_stage': stage.get('NextStage'),
'resolution': stage.get('Resolution'),
'interaction': stage.get('Interaction'),
}
creator = self._cached_creator
loader = self._interaction_loader
return Stage(env, creator, loader, **stage_attrs)
return [load_stage(stage) for stage in stages]
return [self._load_stage(stage, env) for stage in stages]


class Wizard(object):
Expand Down Expand Up @@ -177,8 +178,10 @@ def execute(self):
raise WizardException('Stage not found: %s' % current_stage)
try:
self._push_stage(stage)
stage.execute()
stage_data = stage.execute()
current_stage = stage.get_next_stage()
if current_stage is None:
return stage_data
except Exception as err:
stages = [s.name for (s, _) in self._stage_history]
recovery = self._error_handler(err, stages)
Expand All @@ -199,9 +202,9 @@ def _pop_stages(self, stage_index):
class Stage(object):
"""The Stage object. Contains logic to run all steps of the stage."""

def __init__(self, env, creator, interaction_loader, name=None,
prompt=None, retrieval=None, next_stage=None, resolution=None,
interaction=None):
def __init__(self, env, creator, interaction_loader, wizard_loader,
name=None, prompt=None, retrieval=None, next_stage=None,
resolution=None, interaction=None):
"""Construct a new Stage object.

:type env: :class:`Environment`
Expand Down Expand Up @@ -235,6 +238,7 @@ def __init__(self, env, creator, interaction_loader, name=None,
"""
self._env = env
self._cached_creator = creator
self._wizard_loader = wizard_loader
self._interaction_loader = interaction_loader
self.name = name
self.prompt = prompt
Expand Down Expand Up @@ -270,6 +274,11 @@ def _handle_request_retrieval(self):
# execute operation passing all parameters
return operation(**parameters)

def _handle_wizard_delegation(self):
wizard_name = self.retrieval['Resource']
wizard = self._wizard_loader.load_wizard(wizard_name)
return wizard.execute()

def _handle_retrieval(self):
# In case of no retrieval, empty dict
if not self.retrieval:
Expand All @@ -278,14 +287,15 @@ def _handle_retrieval(self):
data = self._handle_static_retrieval()
elif self.retrieval['Type'] == 'Request':
data = self._handle_request_retrieval()
elif self.retrieval['Type'] == 'Wizard':

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am fine with how you have it, but here is a cool way that we often use to avoid having to update an if statement every time you add a new type:

retrieval_type = self.retrieval.get('Type', 'default').lower()
data = getattr(self, '_handle_' + retrieval_type + '_retrieval', self._handle_default_retrieval)()

data = self._handle_wizard_delegation()
# Apply JMESPath query if given
if self.retrieval.get('Path'):
data = jmespath.search(self.retrieval['Path'], data)

return data

def _handle_interaction(self, data):

# if no interaction step, just forward data
if self.interaction is None:
return data
Expand All @@ -299,6 +309,7 @@ def _handle_resolution(self, data):
if self.resolution.get('Path'):
data = jmespath.search(self.resolution['Path'], data)
self._env.store(self.resolution['Key'], data)
return data

def get_next_stage(self):
"""Resolve the next stage name for the stage after this one.
Expand All @@ -322,7 +333,8 @@ def execute(self):
"""
retrieved_options = self._handle_retrieval()
selected_data = self._handle_interaction(retrieved_options)
self._handle_resolution(selected_data)
resolved_data = self._handle_resolution(selected_data)
return resolved_data


class Environment(object):
Expand Down
3 changes: 2 additions & 1 deletion tests/unit/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,10 +154,11 @@ def test_exit_dot_command_exits_shell():
assert mock_prompter.run.call_count == 1


def test_wizard_can_load_and_execute():
def test_wizard_can_load_and_execute(errstream):
# Proper dot command syntax should load and run a wizard
mock_loader = mock.Mock()
mock_wizard = mock_loader.load_wizard.return_value
mock_wizard.execute.return_value = {}
handler = app.WizardHandler(err=errstream, loader=mock_loader)
handler.run(['.wizard', 'wizname'], None)

Expand Down
63 changes: 55 additions & 8 deletions tests/unit/test_wizard.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import mock
import pytest
import botocore.session

from botocore.loaders import Loader
from botocore.session import Session
from awsshell.utils import FileReadError
from awsshell.wizard import stage_error_handler
Expand Down Expand Up @@ -200,20 +203,20 @@ def test_basic_full_execution(wizard_spec, loader):
def test_basic_full_execution_error(wizard_spec):
# Test that the wizard can handle exceptions in stage execution
session = mock.Mock()
error_handler = mock.Mock()
error_handler.return_value = ('TestStage', 0)
error_handler = mock.Mock(side_effect=[('TestStage', 0), None])
loader = WizardLoader(session, error_handler=error_handler)
wizard_spec['Stages'][0]['NextStage'] = \
{'Type': 'Name', 'Name': 'StageTwo'}
wizard_spec['Stages'][0]['Resolution']['Path'] = '[0].Stage'
stage_three = {'Name': 'StageThree', 'Prompt': 'Text'}
wizard = loader.create_wizard(wizard_spec)
# force an exception once, let it recover, re-run
error = WizardException()
wizard.stages['StageTwo'].execute = mock.Mock(side_effect=[error, {}])
wizard.execute()
# assert error handler was called
assert error_handler.call_count == 1
# force two exceptions, recover once then fail to recover
errors = [WizardException(), TypeError()]
wizard.stages['StageTwo'].execute = mock.Mock(side_effect=errors)
with pytest.raises(TypeError):
wizard.execute()
# assert error handler was called twice
assert error_handler.call_count == 2
assert wizard.stages['StageTwo'].execute.call_count == 2


Expand Down Expand Up @@ -288,6 +291,50 @@ def test_wizard_basic_interaction(wizard_spec):
create.return_value.execute.assert_called_once_with(data)


def test_wizard_basic_delegation(wizard_spec):
main_spec = {
"StartStage": "One",
"Stages": [
{
"Name": "One",
"Prompt": "stage one",
"Retrieval": {
"Type": "Wizard",
"Resource": "SubWizard",
"Path": "FromSub"
}
}
]
}
sub_spec = {
"StartStage": "SubOne",
"Stages": [
{
"Name": "SubOne",
"Prompt": "stage one",
"Retrieval": {
"Type": "Static",
"Resource": {"FromSub": "Result from sub"}
}
}
]
}

mock_loader = mock.Mock(spec=Loader)
mock_loader.list_available_services.return_value = ['wizards']
mock_load_model = mock_loader.load_service_model
mock_load_model.return_value = sub_spec

session = botocore.session.get_session()
session.register_component('data_loader', mock_loader)
loader = WizardLoader(session)
wizard = loader.create_wizard(main_spec)

result = wizard.execute()
mock_load_model.assert_called_once_with('wizards', 'SubWizard')
assert result == 'Result from sub'


exceptions = [
BotoCoreError(),
WizardException('error'),
Expand Down