diff --git a/awsshell/wizard.py b/awsshell/wizard.py index 54c18b8..fdbaa78 100644 --- a/awsshell/wizard.py +++ b/awsshell/wizard.py @@ -1,3 +1,4 @@ +import six import sys import copy import logging @@ -18,6 +19,87 @@ LOG = logging.getLogger(__name__) +class ParamCoercion(object): + """This class coerces string parameters into the correct type. + + By default this converts strings to numerical values if the input + parameters model indicates that the field should be a number. This is to + compensate for the fact that values taken in from prompts will always be + strings and avoids having to create specific interactions for simple + conversions or having to specify the type in the wizard specification. + """ + + _DEFAULT_DICT = { + 'integer': int, + 'float': float, + 'double': float, + 'long': int + } + + def __init__(self, type_dict=_DEFAULT_DICT): + """Initialize a ParamCoercion object. + + :type type_dict: dict + :param type_dict: (Optional) A dictionary of converstions. Keys are + strings representing the shape type name and the values are callables + that given a string will return an instance of an appropriate type for + that shape type. Defaults to only coerce numbers. + """ + self._type_dict = type_dict + + def coerce(self, params, shape): + """Coerce the params according to the given shape. + + :type params: dict + :param params: The parameters to be given to an operation call. + + :type shape: :class:`botocore.model.Shape` + :param shape: The input shape for the desired operation. + + :rtype: dict + :return: The coerced version of the params. + """ + name = shape.type_name + if isinstance(params, dict) and name == 'structure': + return self._coerce_structure(params, shape) + elif isinstance(params, dict) and name == 'map': + return self._coerce_map(params, shape) + elif isinstance(params, (list, tuple)) and name == 'list': + return self._coerce_list(params, shape) + elif isinstance(params, six.string_types) and name in self._type_dict: + target_type = self._type_dict[shape.type_name] + return self._coerce_field(params, target_type) + return params + + def _coerce_structure(self, params, shape): + members = shape.members + coerced = {} + for param in members: + if param in params: + coerced[param] = self.coerce(params[param], members[param]) + return coerced + + def _coerce_map(self, params, shape): + coerced = {} + for key, value in params.items(): + coerced_key = self.coerce(key, shape.key) + coerced[coerced_key] = self.coerce(value, shape.value) + return coerced + + def _coerce_list(self, list_param, shape): + member_shape = shape.member + coerced_list = [] + for item in list_param: + coerced_list.append(self.coerce(item, member_shape)) + return coerced_list + + def _coerce_field(self, value, target_type): + try: + return target_type(value) + except ValueError: + return value + + def stage_error_handler(error, stages, confirm=confirm, prompt=select_prompt): managed_errors = ( ClientError, @@ -264,6 +346,8 @@ def _handle_request_retrieval(self): self._env.resolve_parameters(req.get('EnvParameters', {})) # union of parameters and env_parameters, conflicts favor env params parameters = dict(parameters, **env_parameters) + model = client.meta.service_model.operation_model(req['Operation']) + parameters = ParamCoercion().coerce(parameters, model.input_shape) # if the operation supports pagination, load all results upfront if client.can_paginate(operation_name): # get paginator and create iterator diff --git a/tests/unit/test_wizard.py b/tests/unit/test_wizard.py index 450bf38..78807d3 100644 --- a/tests/unit/test_wizard.py +++ b/tests/unit/test_wizard.py @@ -3,9 +3,10 @@ import botocore.session from botocore.loaders import Loader +from botocore import model from botocore.session import Session from awsshell.utils import FileReadError -from awsshell.wizard import stage_error_handler +from awsshell.wizard import stage_error_handler, ParamCoercion from awsshell.interaction import InteractionException from botocore.exceptions import ClientError, BotoCoreError from awsshell.wizard import Environment, WizardLoader, WizardException @@ -383,3 +384,60 @@ def test_stage_exception_handler_other(error_class): err = error_class() res = stage_error_handler(err, ['stage'], confirm=confirm, prompt=prompt) assert res is None + + +@pytest.fixture +def test_shape(): + shapes = { + "TestShape": { + "type": "structure", + "members": { + "Huge": {"shape": "Long"}, + "Map": {"shape": "TestMap"}, + "Scale": {"shape": "Double"}, + "Count": {"shape": "Integer"}, + "Items": {"shape": "TestList"} + } + }, + "TestList": { + "type": "list", + "member": { + "shape": "Float" + } + }, + "TestMap": { + "type": "map", + "key": {"shape": "Double"}, + "value": {"shape": "Integer"} + }, + "Long": {"type": "long"}, + "Float": {"type": "float"}, + "Double": {"type": "double"}, + "String": {"type": "string"}, + "Integer": {"type": "integer"} + } + return model.ShapeResolver(shapes).get_shape_by_name('TestShape') + + +def test_param_coercion_numbers(test_shape): + # verify coercion will convert strings to numbers according to shape + params = { + "Count": "5", + "Scale": "2.3", + "Items": ["5", "3.14"], + "Huge": "92233720368547758070", + "Map": {"2": "12"} + } + coerced = ParamCoercion().coerce(params, test_shape) + assert isinstance(coerced['Count'], int) + assert isinstance(coerced['Scale'], float) + assert all(isinstance(item, float) for item in coerced['Items']) + assert coerced['Map'][2] == 12 + assert coerced['Huge'] == 92233720368547758070 + + +def test_param_coercion_failure(test_shape): + # verify coercion leaves the field the same when it fails + params = {"Count": "fifty"} + coerced = ParamCoercion().coerce(params, test_shape) + assert coerced["Count"] == params["Count"]