diff --git a/awsshell/app.py b/awsshell/app.py index b4278ad..b470b50 100644 --- a/awsshell/app.py +++ b/awsshell/app.py @@ -149,11 +149,12 @@ def run(self, command, application): class WizardHandler(object): - def __init__(self, output=sys.stdout, err=sys.stderr, - loader=WizardLoader()): + def __init__(self, output=sys.stdout, err=sys.stderr, loader=None): self._output = output self._err = err self._wizard_loader = loader + if self._wizard_loader is None: + self._wizard_loader = WizardLoader() def run(self, command, application): """Run the specified wizard. diff --git a/awsshell/wizard.py b/awsshell/wizard.py index c5534bf..6effa53 100644 --- a/awsshell/wizard.py +++ b/awsshell/wizard.py @@ -111,20 +111,21 @@ def create_wizard(self, model): stages = self._load_stages(model.get('Stages'), env) return Wizard(start_stage, stages, env, self._error_handler) + def _load_stage(self, stage, env): + stage_attrs = { + 'name': stage.get('Name'), + 'prompt': stage.get('Prompt'), + 'retrieval': stage.get('Retrieval'), + 'next_stage': stage.get('NextStage'), + 'resolution': stage.get('Resolution'), + 'interaction': stage.get('Interaction'), + } + creator = self._cached_creator + interaction = self._interaction_loader + return Stage(env, creator, interaction, self, **stage_attrs) + def _load_stages(self, stages, env): - def load_stage(stage): - stage_attrs = { - 'name': stage.get('Name'), - 'prompt': stage.get('Prompt'), - 'retrieval': stage.get('Retrieval'), - 'next_stage': stage.get('NextStage'), - 'resolution': stage.get('Resolution'), - 'interaction': stage.get('Interaction'), - } - creator = self._cached_creator - loader = self._interaction_loader - return Stage(env, creator, loader, **stage_attrs) - return [load_stage(stage) for stage in stages] + return [self._load_stage(stage, env) for stage in stages] class Wizard(object): @@ -177,8 +178,10 @@ def execute(self): raise WizardException('Stage not found: %s' % current_stage) try: self._push_stage(stage) - stage.execute() + stage_data = stage.execute() current_stage = stage.get_next_stage() + if current_stage is None: + return stage_data except Exception as err: stages = [s.name for (s, _) in self._stage_history] recovery = self._error_handler(err, stages) @@ -199,9 +202,9 @@ def _pop_stages(self, stage_index): class Stage(object): """The Stage object. Contains logic to run all steps of the stage.""" - def __init__(self, env, creator, interaction_loader, name=None, - prompt=None, retrieval=None, next_stage=None, resolution=None, - interaction=None): + def __init__(self, env, creator, interaction_loader, wizard_loader, + name=None, prompt=None, retrieval=None, next_stage=None, + resolution=None, interaction=None): """Construct a new Stage object. :type env: :class:`Environment` @@ -235,6 +238,7 @@ def __init__(self, env, creator, interaction_loader, name=None, """ self._env = env self._cached_creator = creator + self._wizard_loader = wizard_loader self._interaction_loader = interaction_loader self.name = name self.prompt = prompt @@ -270,6 +274,11 @@ def _handle_request_retrieval(self): # execute operation passing all parameters return operation(**parameters) + def _handle_wizard_delegation(self): + wizard_name = self.retrieval['Resource'] + wizard = self._wizard_loader.load_wizard(wizard_name) + return wizard.execute() + def _handle_retrieval(self): # In case of no retrieval, empty dict if not self.retrieval: @@ -278,6 +287,8 @@ def _handle_retrieval(self): data = self._handle_static_retrieval() elif self.retrieval['Type'] == 'Request': data = self._handle_request_retrieval() + elif self.retrieval['Type'] == 'Wizard': + data = self._handle_wizard_delegation() # Apply JMESPath query if given if self.retrieval.get('Path'): data = jmespath.search(self.retrieval['Path'], data) @@ -285,7 +296,6 @@ def _handle_retrieval(self): return data def _handle_interaction(self, data): - # if no interaction step, just forward data if self.interaction is None: return data @@ -299,6 +309,7 @@ def _handle_resolution(self, data): if self.resolution.get('Path'): data = jmespath.search(self.resolution['Path'], data) self._env.store(self.resolution['Key'], data) + return data def get_next_stage(self): """Resolve the next stage name for the stage after this one. @@ -322,7 +333,8 @@ def execute(self): """ retrieved_options = self._handle_retrieval() selected_data = self._handle_interaction(retrieved_options) - self._handle_resolution(selected_data) + resolved_data = self._handle_resolution(selected_data) + return resolved_data class Environment(object): diff --git a/tests/unit/test_app.py b/tests/unit/test_app.py index 5c5f2ae..c0d7171 100644 --- a/tests/unit/test_app.py +++ b/tests/unit/test_app.py @@ -154,10 +154,11 @@ def test_exit_dot_command_exits_shell(): assert mock_prompter.run.call_count == 1 -def test_wizard_can_load_and_execute(): +def test_wizard_can_load_and_execute(errstream): # Proper dot command syntax should load and run a wizard mock_loader = mock.Mock() mock_wizard = mock_loader.load_wizard.return_value + mock_wizard.execute.return_value = {} handler = app.WizardHandler(err=errstream, loader=mock_loader) handler.run(['.wizard', 'wizname'], None) diff --git a/tests/unit/test_wizard.py b/tests/unit/test_wizard.py index 98331d3..28ddef6 100644 --- a/tests/unit/test_wizard.py +++ b/tests/unit/test_wizard.py @@ -1,5 +1,8 @@ import mock import pytest +import botocore.session + +from botocore.loaders import Loader from botocore.session import Session from awsshell.utils import FileReadError from awsshell.wizard import stage_error_handler @@ -200,20 +203,20 @@ def test_basic_full_execution(wizard_spec, loader): def test_basic_full_execution_error(wizard_spec): # Test that the wizard can handle exceptions in stage execution session = mock.Mock() - error_handler = mock.Mock() - error_handler.return_value = ('TestStage', 0) + error_handler = mock.Mock(side_effect=[('TestStage', 0), None]) loader = WizardLoader(session, error_handler=error_handler) wizard_spec['Stages'][0]['NextStage'] = \ {'Type': 'Name', 'Name': 'StageTwo'} wizard_spec['Stages'][0]['Resolution']['Path'] = '[0].Stage' stage_three = {'Name': 'StageThree', 'Prompt': 'Text'} wizard = loader.create_wizard(wizard_spec) - # force an exception once, let it recover, re-run - error = WizardException() - wizard.stages['StageTwo'].execute = mock.Mock(side_effect=[error, {}]) - wizard.execute() - # assert error handler was called - assert error_handler.call_count == 1 + # force two exceptions, recover once then fail to recover + errors = [WizardException(), TypeError()] + wizard.stages['StageTwo'].execute = mock.Mock(side_effect=errors) + with pytest.raises(TypeError): + wizard.execute() + # assert error handler was called twice + assert error_handler.call_count == 2 assert wizard.stages['StageTwo'].execute.call_count == 2 @@ -288,6 +291,50 @@ def test_wizard_basic_interaction(wizard_spec): create.return_value.execute.assert_called_once_with(data) +def test_wizard_basic_delegation(wizard_spec): + main_spec = { + "StartStage": "One", + "Stages": [ + { + "Name": "One", + "Prompt": "stage one", + "Retrieval": { + "Type": "Wizard", + "Resource": "SubWizard", + "Path": "FromSub" + } + } + ] + } + sub_spec = { + "StartStage": "SubOne", + "Stages": [ + { + "Name": "SubOne", + "Prompt": "stage one", + "Retrieval": { + "Type": "Static", + "Resource": {"FromSub": "Result from sub"} + } + } + ] + } + + mock_loader = mock.Mock(spec=Loader) + mock_loader.list_available_services.return_value = ['wizards'] + mock_load_model = mock_loader.load_service_model + mock_load_model.return_value = sub_spec + + session = botocore.session.get_session() + session.register_component('data_loader', mock_loader) + loader = WizardLoader(session) + wizard = loader.create_wizard(main_spec) + + result = wizard.execute() + mock_load_model.assert_called_once_with('wizards', 'SubWizard') + assert result == 'Result from sub' + + exceptions = [ BotoCoreError(), WizardException('error'),