diff --git a/samcli/commands/remote/invoke/cli.py b/samcli/commands/remote/invoke/cli.py index d3ebd5d36f..4f7ed1b081 100644 --- a/samcli/commands/remote/invoke/cli.py +++ b/samcli/commands/remote/invoke/cli.py @@ -1,7 +1,6 @@ """CLI command for "invoke" command.""" import logging from io import TextIOWrapper -from typing import cast import click @@ -124,16 +123,6 @@ def do_cli( payload=event, payload_file=event_file, parameters=parameter, output_format=output_format ) - remote_invoke_result = remote_invoke_context.run(remote_invoke_input=remote_invoke_input) - - if remote_invoke_result.is_succeeded(): - LOG.debug("Invoking resource was successfull, writing response to stdout") - if remote_invoke_result.log_output: - LOG.debug("Writing log output to stderr") - remote_invoke_context.stderr.write(remote_invoke_result.log_output.encode()) - output_response = cast(str, remote_invoke_result.response) - remote_invoke_context.stdout.write(output_response.encode()) - else: - raise cast(Exception, remote_invoke_result.exception) + remote_invoke_context.run(remote_invoke_input=remote_invoke_input) except (ErrorBotoApiCallException, InvalideBotoResponseException, InvalidResourceBotoParameterException) as ex: raise UserException(str(ex), wrapped_from=ex.__class__.__name__) from ex diff --git a/samcli/commands/remote/remote_invoke_context.py b/samcli/commands/remote/remote_invoke_context.py index 254f504b33..c1ca48193d 100644 --- a/samcli/commands/remote/remote_invoke_context.py +++ b/samcli/commands/remote/remote_invoke_context.py @@ -2,6 +2,7 @@ Context object used by `sam remote invoke` command """ import logging +from dataclasses import dataclass from typing import Optional, cast from botocore.exceptions import ClientError @@ -15,7 +16,12 @@ UnsupportedServiceForRemoteInvoke, ) from samcli.lib.remote_invoke.remote_invoke_executor_factory import RemoteInvokeExecutorFactory -from samcli.lib.remote_invoke.remote_invoke_executors import RemoteInvokeExecutionInfo +from samcli.lib.remote_invoke.remote_invoke_executors import ( + RemoteInvokeConsumer, + RemoteInvokeExecutionInfo, + RemoteInvokeLogOutput, + RemoteInvokeResponse, +) from samcli.lib.utils import osutils from samcli.lib.utils.arn_utils import ARNParts, InvalidArnValue from samcli.lib.utils.boto_utils import BotoProviderType, get_client_error_code @@ -61,7 +67,7 @@ def __enter__(self) -> "RemoteInvokeContext": def __exit__(self, *args) -> None: pass - def run(self, remote_invoke_input: RemoteInvokeExecutionInfo) -> RemoteInvokeExecutionInfo: + def run(self, remote_invoke_input: RemoteInvokeExecutionInfo) -> None: """ Instantiates remote invoke executor with populated resource summary information, executes it with the provided input & returns its response back to the caller. If no executor can be instantiated it raises @@ -72,11 +78,6 @@ def run(self, remote_invoke_input: RemoteInvokeExecutionInfo) -> RemoteInvokeExe remote_invoke_input: RemoteInvokeExecutionInfo RemoteInvokeExecutionInfo which contains the payload and other information that will be required during the invocation - - Returns - ------- - RemoteInvokeExecutionInfo - Populates result and exception info (if any) and returns back to the caller """ if not self._resource_summary: raise AmbiguousResourceForRemoteInvoke( @@ -85,13 +86,18 @@ def run(self, remote_invoke_input: RemoteInvokeExecutionInfo) -> RemoteInvokeExe ) remote_invoke_executor_factory = RemoteInvokeExecutorFactory(self._boto_client_provider) - remote_invoke_executor = remote_invoke_executor_factory.create_remote_invoke_executor(self._resource_summary) + remote_invoke_executor = remote_invoke_executor_factory.create_remote_invoke_executor( + self._resource_summary, + remote_invoke_input.output_format, + DefaultRemoteInvokeResponseConsumer(self.stdout), + DefaultRemoteInvokeLogConsumer(self.stderr), + ) if not remote_invoke_executor: raise NoExecutorFoundForRemoteInvoke( f"Resource type {self._resource_summary.resource_type} is not supported for remote invoke" ) - return remote_invoke_executor.execute(remote_invoke_input) + remote_invoke_executor.execute(remote_invoke_input) def _populate_resource_summary(self) -> None: """ @@ -225,3 +231,27 @@ def stderr(self) -> StreamWriter: """ stream = osutils.stderr() return StreamWriter(stream, auto_flush=True) + + +@dataclass +class DefaultRemoteInvokeResponseConsumer(RemoteInvokeConsumer[RemoteInvokeResponse]): + """ + Default RemoteInvokeResponse consumer, writes given response event to the configured StreamWriter + """ + + _stream_writer: StreamWriter + + def consume(self, remote_invoke_response: RemoteInvokeResponse) -> None: + self._stream_writer.write(cast(str, remote_invoke_response.response).encode()) + + +@dataclass +class DefaultRemoteInvokeLogConsumer(RemoteInvokeConsumer[RemoteInvokeLogOutput]): + """ + Default RemoteInvokeLogOutput consumer, writes given log event to the configured StreamWriter + """ + + _stream_writer: StreamWriter + + def consume(self, remote_invoke_response: RemoteInvokeLogOutput) -> None: + self._stream_writer.write(remote_invoke_response.log_output.encode()) diff --git a/samcli/lib/remote_invoke/lambda_invoke_executors.py b/samcli/lib/remote_invoke/lambda_invoke_executors.py index 30f046127c..936cd89289 100644 --- a/samcli/lib/remote_invoke/lambda_invoke_executors.py +++ b/samcli/lib/remote_invoke/lambda_invoke_executors.py @@ -20,8 +20,11 @@ from samcli.lib.remote_invoke.remote_invoke_executors import ( BotoActionExecutor, RemoteInvokeExecutionInfo, + RemoteInvokeIterableResponseType, + RemoteInvokeLogOutput, RemoteInvokeOutputFormat, RemoteInvokeRequestResponseMapper, + RemoteInvokeResponse, ) from samcli.lib.utils import boto_utils @@ -45,10 +48,12 @@ class AbstractLambdaInvokeExecutor(BotoActionExecutor, ABC): _lambda_client: Any _function_name: str + _remote_output_format: RemoteInvokeOutputFormat - def __init__(self, lambda_client: Any, function_name: str): + def __init__(self, lambda_client: Any, function_name: str, remote_output_format: RemoteInvokeOutputFormat): self._lambda_client = lambda_client self._function_name = function_name + self._remote_output_format = remote_output_format self.request_parameters = {"InvocationType": "RequestResponse", "LogType": "Tail"} def validate_action_parameters(self, parameters: dict) -> None: @@ -65,12 +70,15 @@ def validate_action_parameters(self, parameters: dict) -> None: else: self.request_parameters[parameter_key] = parameter_value - def _execute_action(self, payload: str): + def _execute_action(self, payload: str) -> RemoteInvokeIterableResponseType: self.request_parameters[FUNCTION_NAME] = self._function_name self.request_parameters[PAYLOAD] = payload + return self._execute_lambda_invoke(payload) + + def _execute_boto_call(self, boto_client_method) -> dict: try: - return self._execute_lambda_invoke(payload) + return cast(dict, boto_client_method(**self.request_parameters)) except ParamValidationError as param_val_ex: raise InvalidResourceBotoParameterException( f"Invalid parameter key provided." @@ -86,8 +94,8 @@ def _execute_action(self, payload: str): raise ErrorBotoApiCallException(client_ex) from client_ex @abstractmethod - def _execute_lambda_invoke(self, payload: str): - pass + def _execute_lambda_invoke(self, payload: str) -> RemoteInvokeIterableResponseType: + raise NotImplementedError() class LambdaInvokeExecutor(AbstractLambdaInvokeExecutor): @@ -95,14 +103,21 @@ class LambdaInvokeExecutor(AbstractLambdaInvokeExecutor): Calls "invoke" method of "lambda" service with given input. """ - def _execute_lambda_invoke(self, payload: str) -> dict: + def _execute_lambda_invoke(self, payload: str) -> RemoteInvokeIterableResponseType: LOG.debug( "Calling lambda_client.invoke with FunctionName:%s, Payload:%s, parameters:%s", self._function_name, payload, self.request_parameters, ) - return cast(dict, self._lambda_client.invoke(**self.request_parameters)) + lambda_response = self._execute_boto_call(self._lambda_client.invoke) + if self._remote_output_format == RemoteInvokeOutputFormat.RAW: + yield RemoteInvokeResponse(lambda_response) + if self._remote_output_format == RemoteInvokeOutputFormat.DEFAULT: + log_result = lambda_response.get(LOG_RESULT) + if log_result: + yield RemoteInvokeLogOutput(base64.b64decode(log_result).decode("utf-8")) + yield RemoteInvokeResponse(cast(StreamingBody, lambda_response.get(PAYLOAD)).read().decode("utf-8")) class LambdaInvokeWithResponseStreamExecutor(AbstractLambdaInvokeExecutor): @@ -110,17 +125,29 @@ class LambdaInvokeWithResponseStreamExecutor(AbstractLambdaInvokeExecutor): Calls "invoke_with_response_stream" method of "lambda" service with given input. """ - def _execute_lambda_invoke(self, payload: str) -> dict: + def _execute_lambda_invoke(self, payload: str) -> RemoteInvokeIterableResponseType: LOG.debug( "Calling lambda_client.invoke_with_response_stream with FunctionName:%s, Payload:%s, parameters:%s", self._function_name, payload, self.request_parameters, ) - return cast(dict, self._lambda_client.invoke_with_response_stream(**self.request_parameters)) + lambda_response = self._execute_boto_call(self._lambda_client.invoke_with_response_stream) + if self._remote_output_format == RemoteInvokeOutputFormat.RAW: + yield RemoteInvokeResponse(lambda_response) + if self._remote_output_format == RemoteInvokeOutputFormat.DEFAULT: + event_stream: EventStream = lambda_response.get(EVENT_STREAM, []) + for event in event_stream: + if PAYLOAD_CHUNK in event: + yield RemoteInvokeResponse(event.get(PAYLOAD_CHUNK).get(PAYLOAD).decode("utf-8")) + if INVOKE_COMPLETE in event: + if LOG_RESULT in event.get(INVOKE_COMPLETE): + yield RemoteInvokeLogOutput( + base64.b64decode(event.get(INVOKE_COMPLETE).get(LOG_RESULT)).decode("utf-8") + ) -class DefaultConvertToJSON(RemoteInvokeRequestResponseMapper): +class DefaultConvertToJSON(RemoteInvokeRequestResponseMapper[RemoteInvokeExecutionInfo]): """ If a regular string is provided as payload, this class will convert it into a JSON object """ @@ -143,13 +170,13 @@ def map(self, test_input: RemoteInvokeExecutionInfo) -> RemoteInvokeExecutionInf return test_input -class LambdaResponseConverter(RemoteInvokeRequestResponseMapper): +class LambdaResponseConverter(RemoteInvokeRequestResponseMapper[RemoteInvokeResponse]): """ This class helps to convert response from lambda service. Normally lambda service returns 'Payload' field as stream, this class converts that stream into string object """ - def map(self, remote_invoke_input: RemoteInvokeExecutionInfo) -> RemoteInvokeExecutionInfo: + def map(self, remote_invoke_input: RemoteInvokeResponse) -> RemoteInvokeResponse: LOG.debug("Mapping Lambda response to string object") if not isinstance(remote_invoke_input.response, dict): raise InvalideBotoResponseException("Invalid response type received from Lambda service, expecting dict") @@ -168,7 +195,7 @@ class LambdaStreamResponseConverter(RemoteInvokeRequestResponseMapper): This mapper, gets all 'PayloadChunk's and 'InvokeComplete' events and decodes them for next mapper. """ - def map(self, remote_invoke_input: RemoteInvokeExecutionInfo) -> RemoteInvokeExecutionInfo: + def map(self, remote_invoke_input: RemoteInvokeResponse) -> RemoteInvokeResponse: LOG.debug("Mapping Lambda response to string object") if not isinstance(remote_invoke_input.response, dict): raise InvalideBotoResponseException("Invalid response type received from Lambda service, expecting dict") @@ -180,70 +207,11 @@ def map(self, remote_invoke_input: RemoteInvokeExecutionInfo) -> RemoteInvokeExe decoded_payload_chunk = event.get(PAYLOAD_CHUNK).get(PAYLOAD).decode("utf-8") decoded_event_stream.append({PAYLOAD_CHUNK: {PAYLOAD: decoded_payload_chunk}}) if INVOKE_COMPLETE in event: - log_output = event.get(INVOKE_COMPLETE).get(LOG_RESULT, b"") - decoded_event_stream.append({INVOKE_COMPLETE: {LOG_RESULT: log_output}}) + decoded_event_stream.append(event) remote_invoke_input.response[EVENT_STREAM] = decoded_event_stream return remote_invoke_input -class LambdaResponseOutputFormatter(RemoteInvokeRequestResponseMapper): - """ - This class helps to format output response for lambda service that will be printed on the CLI. - If LogResult is found in the response, the decoded LogResult will be written to stderr. The response payload will - be written to stdout. - """ - - def map(self, remote_invoke_input: RemoteInvokeExecutionInfo) -> RemoteInvokeExecutionInfo: - """ - Maps the lambda response output to the type of output format specified as user input. - If output_format is original-boto-response, write the original boto API response - to stdout. - """ - if remote_invoke_input.output_format == RemoteInvokeOutputFormat.DEFAULT: - LOG.debug("Formatting Lambda output response") - boto_response = cast(dict, remote_invoke_input.response) - log_field = boto_response.get(LOG_RESULT) - if log_field: - log_result = base64.b64decode(log_field).decode("utf-8") - remote_invoke_input.log_output = log_result - - invocation_type_parameter = remote_invoke_input.parameters.get("InvocationType") - if invocation_type_parameter and invocation_type_parameter != "RequestResponse": - remote_invoke_input.response = {"StatusCode": boto_response["StatusCode"]} - else: - remote_invoke_input.response = boto_response.get(PAYLOAD) - - return remote_invoke_input - - -class LambdaStreamResponseOutputFormatter(RemoteInvokeRequestResponseMapper): - """ - This class helps to format streaming output response for lambda service that will be printed on the CLI. - It loops through EventStream elements and adds them to response, and once InvokeComplete is reached, it updates - log_output and response objects in remote_invoke_input. - """ - - def map(self, remote_invoke_input: RemoteInvokeExecutionInfo) -> RemoteInvokeExecutionInfo: - """ - Maps the lambda response output to the type of output format specified as user input. - If output_format is original-boto-response, write the original boto API response - to stdout. - """ - if remote_invoke_input.output_format == RemoteInvokeOutputFormat.DEFAULT: - LOG.debug("Formatting Lambda output response") - boto_response = cast(dict, remote_invoke_input.response) - combined_response = "" - for event in boto_response.get(EVENT_STREAM, []): - if PAYLOAD_CHUNK in event: - payload_chunk = event.get(PAYLOAD_CHUNK).get(PAYLOAD) - combined_response = f"{combined_response}{payload_chunk}" - if INVOKE_COMPLETE in event: - log_result = base64.b64decode(event.get(INVOKE_COMPLETE).get(LOG_RESULT)).decode("utf-8") - remote_invoke_input.log_output = log_result - remote_invoke_input.response = combined_response - return remote_invoke_input - - def _is_function_invoke_mode_response_stream(lambda_client: Any, function_name: str): """ Returns True if given function has RESPONSE_STREAM as InvokeMode, False otherwise diff --git a/samcli/lib/remote_invoke/remote_invoke_executor_factory.py b/samcli/lib/remote_invoke/remote_invoke_executor_factory.py index 33ec958e1f..129f9302d9 100644 --- a/samcli/lib/remote_invoke/remote_invoke_executor_factory.py +++ b/samcli/lib/remote_invoke/remote_invoke_executor_factory.py @@ -9,12 +9,17 @@ LambdaInvokeExecutor, LambdaInvokeWithResponseStreamExecutor, LambdaResponseConverter, - LambdaResponseOutputFormatter, LambdaStreamResponseConverter, - LambdaStreamResponseOutputFormatter, _is_function_invoke_mode_response_stream, ) -from samcli.lib.remote_invoke.remote_invoke_executors import RemoteInvokeExecutor, ResponseObjectToJsonStringMapper +from samcli.lib.remote_invoke.remote_invoke_executors import ( + RemoteInvokeConsumer, + RemoteInvokeExecutor, + RemoteInvokeLogOutput, + RemoteInvokeOutputFormat, + RemoteInvokeResponse, + ResponseObjectToJsonStringMapper, +) from samcli.lib.utils.cloudformation import CloudFormationResourceSummary from samcli.lib.utils.resources import ( AWS_LAMBDA_FUNCTION, @@ -29,7 +34,11 @@ def __init__(self, boto_client_provider: Callable[[str], Any]): self._boto_client_provider = boto_client_provider def create_remote_invoke_executor( - self, cfn_resource_summary: CloudFormationResourceSummary + self, + cfn_resource_summary: CloudFormationResourceSummary, + output_format: RemoteInvokeOutputFormat, + response_consumer: RemoteInvokeConsumer[RemoteInvokeResponse], + log_consumer: RemoteInvokeConsumer[RemoteInvokeLogOutput], ) -> Optional[RemoteInvokeExecutor]: """ Creates remote invoker with given CloudFormationResourceSummary @@ -38,8 +47,14 @@ def create_remote_invoke_executor( ---------- cfn_resource_summary : CloudFormationResourceSummary Information about the resource, which RemoteInvokeExecutor will be created for - - Returns: + output_format: RemoteInvokeOutputFormat + Output format of the current remote invoke execution, passed down to executor itself + response_consumer: RemoteInvokeConsumer[RemoteInvokeResponse] + Consumer instance which can process RemoteInvokeResponse events + log_consumer: RemoteInvokeConsumer[RemoteInvokeLogOutput] + Consumer instance which can process RemoteInvokeLogOutput events + + Returns ------- Optional[RemoteInvokeExecutor] RemoteInvoker instance for the given CFN resource, None if the resource is not supported yet @@ -50,7 +65,7 @@ def create_remote_invoke_executor( ) if remote_invoke_executor: - return remote_invoke_executor(self, cfn_resource_summary) + return remote_invoke_executor(self, cfn_resource_summary, output_format, response_consumer, log_consumer) LOG.error( "Can't find remote invoke executor instance for resource %s for type %s", @@ -60,7 +75,13 @@ def create_remote_invoke_executor( return None - def _create_lambda_boto_executor(self, cfn_resource_summary: CloudFormationResourceSummary) -> RemoteInvokeExecutor: + def _create_lambda_boto_executor( + self, + cfn_resource_summary: CloudFormationResourceSummary, + remote_invoke_output_format: RemoteInvokeOutputFormat, + response_consumer: RemoteInvokeConsumer[RemoteInvokeResponse], + log_consumer: RemoteInvokeConsumer[RemoteInvokeLogOutput], + ) -> RemoteInvokeExecutor: """Creates a remote invoke executor for Lambda resource type based on the boto action being called. @@ -69,37 +90,55 @@ def _create_lambda_boto_executor(self, cfn_resource_summary: CloudFormationResou :return: Returns the created remote invoke Executor """ lambda_client = self._boto_client_provider("lambda") + mappers = [] if _is_function_invoke_mode_response_stream(lambda_client, cfn_resource_summary.physical_resource_id): LOG.debug("Creating response stream invocator for function %s", cfn_resource_summary.physical_resource_id) - return RemoteInvokeExecutor( - request_mappers=[DefaultConvertToJSON()], - response_mappers=[ + + if remote_invoke_output_format == RemoteInvokeOutputFormat.RAW: + mappers = [ LambdaStreamResponseConverter(), - LambdaStreamResponseOutputFormatter(), ResponseObjectToJsonStringMapper(), - ], + ] + + return RemoteInvokeExecutor( + request_mappers=[DefaultConvertToJSON()], + response_mappers=mappers, boto_action_executor=LambdaInvokeWithResponseStreamExecutor( - lambda_client, - cfn_resource_summary.physical_resource_id, + lambda_client, cfn_resource_summary.physical_resource_id, remote_invoke_output_format ), + response_consumer=response_consumer, + log_consumer=log_consumer, ) - return RemoteInvokeExecutor( - request_mappers=[DefaultConvertToJSON()], - response_mappers=[ + if remote_invoke_output_format == RemoteInvokeOutputFormat.RAW: + mappers = [ LambdaResponseConverter(), - LambdaResponseOutputFormatter(), ResponseObjectToJsonStringMapper(), - ], + ] + + return RemoteInvokeExecutor( + request_mappers=[DefaultConvertToJSON()], + response_mappers=mappers, boto_action_executor=LambdaInvokeExecutor( - lambda_client, - cfn_resource_summary.physical_resource_id, + lambda_client, cfn_resource_summary.physical_resource_id, remote_invoke_output_format ), + response_consumer=response_consumer, + log_consumer=log_consumer, ) # mapping definition for each supported resource type REMOTE_INVOKE_EXECUTOR_MAPPING: Dict[ - str, Callable[["RemoteInvokeExecutorFactory", CloudFormationResourceSummary], RemoteInvokeExecutor] + str, + Callable[ + [ + "RemoteInvokeExecutorFactory", + CloudFormationResourceSummary, + RemoteInvokeOutputFormat, + RemoteInvokeConsumer[RemoteInvokeResponse], + RemoteInvokeConsumer[RemoteInvokeLogOutput], + ], + RemoteInvokeExecutor, + ], ] = { AWS_LAMBDA_FUNCTION: _create_lambda_boto_executor, } diff --git a/samcli/lib/remote_invoke/remote_invoke_executors.py b/samcli/lib/remote_invoke/remote_invoke_executors.py index cb9eb6887b..0c69a9d5bf 100644 --- a/samcli/lib/remote_invoke/remote_invoke_executors.py +++ b/samcli/lib/remote_invoke/remote_invoke_executors.py @@ -4,14 +4,40 @@ import json import logging from abc import ABC, abstractmethod +from dataclasses import dataclass from enum import Enum from io import TextIOWrapper from pathlib import Path -from typing import Any, Callable, List, Optional, Union, cast +from typing import Any, Callable, Generic, Iterable, List, Optional, TypeVar, Union, cast + +from typing_extensions import TypeAlias LOG = logging.getLogger(__name__) +@dataclass +class RemoteInvokeResponse: + """ + Dataclass that contains response object of the remote invoke execution. + dict for raw events, str for other ones + """ + + response: Union[str, dict] + + +@dataclass +class RemoteInvokeLogOutput: + """ + Dataclass that contains log objects of the remote invoke execution + """ + + log_output: str + + +# type alias to keep consistency between different places for remote invoke return type +RemoteInvokeIterableResponseType: TypeAlias = Iterable[Union[RemoteInvokeResponse, RemoteInvokeLogOutput]] + + class RemoteInvokeOutputFormat(Enum): """ Types of output formats used to by remote invoke @@ -69,7 +95,10 @@ def is_succeeded(self) -> bool: return bool(self.response) -class RemoteInvokeRequestResponseMapper(ABC): +RemoteInvokeResponseType = TypeVar("RemoteInvokeResponseType") + + +class RemoteInvokeRequestResponseMapper(Generic[RemoteInvokeResponseType]): """ Mapper definition which can be used map remote invoke requests or responses. @@ -81,7 +110,13 @@ class RemoteInvokeRequestResponseMapper(ABC): """ @abstractmethod - def map(self, remote_invoke_input: RemoteInvokeExecutionInfo) -> RemoteInvokeExecutionInfo: + def map(self, remote_invoke_input: RemoteInvokeResponseType) -> RemoteInvokeResponseType: + raise NotImplementedError() + + +class RemoteInvokeConsumer(Generic[RemoteInvokeResponseType]): + @abstractmethod + def consume(self, remote_invoke_response: RemoteInvokeResponseType) -> None: raise NotImplementedError() @@ -90,7 +125,7 @@ class ResponseObjectToJsonStringMapper(RemoteInvokeRequestResponseMapper): Maps response object inside RemoteInvokeExecutionInfo into formatted JSON string with multiple lines """ - def map(self, remote_invoke_input: RemoteInvokeExecutionInfo) -> RemoteInvokeExecutionInfo: + def map(self, remote_invoke_input: RemoteInvokeResponse) -> RemoteInvokeResponse: LOG.debug("Converting response object into JSON") remote_invoke_input.response = json.dumps(remote_invoke_input.response, indent=2) return remote_invoke_input @@ -103,7 +138,7 @@ class BotoActionExecutor(ABC): """ @abstractmethod - def _execute_action(self, payload: str) -> dict: + def _execute_action(self, payload: str) -> RemoteInvokeIterableResponseType: """ Specific boto3 API call implementation. @@ -128,7 +163,7 @@ def validate_action_parameters(self, parameters: dict): """ raise NotImplementedError() - def _execute_action_file(self, payload_file: TextIOWrapper) -> dict: + def _execute_action_file(self, payload_file: TextIOWrapper) -> RemoteInvokeIterableResponseType: """ Different implementation which is specific to a file path. Some boto3 APIs may accept a file path rather than a string. This implementation targets these options to support different file types @@ -147,20 +182,21 @@ def _execute_action_file(self, payload_file: TextIOWrapper) -> dict: """ return self._execute_action(payload_file.read()) - def execute(self, remote_invoke_input: RemoteInvokeExecutionInfo) -> RemoteInvokeExecutionInfo: + def execute(self, remote_invoke_input: RemoteInvokeExecutionInfo) -> RemoteInvokeIterableResponseType: """ Executes boto3 API and updates response or exception object depending on the result Parameters ---------- remote_invoke_input : RemoteInvokeExecutionInfo - RemoteInvokeExecutionInfo details which contains payload or payload file information + Remote execution details which contains payload or payload file information - Returns : RemoteInvokeExecutionInfo + Returns ------- - Updates response or exception fields of given input and returns it + RemoteInvokeIterableResponseType + Returns iterable response, see response type definition for details """ - action_executor: Callable[[Any], dict] + action_executor: Callable[[Any], Iterable[Union[RemoteInvokeResponse, RemoteInvokeLogOutput]]] payload: Union[str, Path] # if a file pointed is provided for payload, use specific payload and its function here @@ -172,13 +208,7 @@ def execute(self, remote_invoke_input: RemoteInvokeExecutionInfo) -> RemoteInvok payload = cast(str, remote_invoke_input.payload) # execute boto3 API, and update result if it is successful, update exception otherwise - try: - action_response = action_executor(payload) - remote_invoke_input.response = action_response - except Exception as e: - remote_invoke_input.exception = e - - return remote_invoke_input + return action_executor(payload) class RemoteInvokeExecutor: @@ -190,21 +220,28 @@ class RemoteInvokeExecutor: Once the result is returned, if it is successful, response have been mapped with list of response mappers """ - _request_mappers: List[RemoteInvokeRequestResponseMapper] - _response_mappers: List[RemoteInvokeRequestResponseMapper] + _request_mappers: List[RemoteInvokeRequestResponseMapper[RemoteInvokeExecutionInfo]] + _response_mappers: List[RemoteInvokeRequestResponseMapper[RemoteInvokeResponse]] _boto_action_executor: BotoActionExecutor + _response_consumer: RemoteInvokeConsumer[RemoteInvokeResponse] + _log_consumer: RemoteInvokeConsumer[RemoteInvokeLogOutput] + def __init__( self, - request_mappers: List[RemoteInvokeRequestResponseMapper], - response_mappers: List[RemoteInvokeRequestResponseMapper], + request_mappers: List[RemoteInvokeRequestResponseMapper[RemoteInvokeExecutionInfo]], + response_mappers: List[RemoteInvokeRequestResponseMapper[RemoteInvokeResponse]], boto_action_executor: BotoActionExecutor, + response_consumer: RemoteInvokeConsumer[RemoteInvokeResponse], + log_consumer: RemoteInvokeConsumer[RemoteInvokeLogOutput], ): self._request_mappers = request_mappers self._response_mappers = response_mappers self._boto_action_executor = boto_action_executor + self._response_consumer = response_consumer + self._log_consumer = log_consumer - def execute(self, remote_invoke_input: RemoteInvokeExecutionInfo) -> RemoteInvokeExecutionInfo: + def execute(self, remote_invoke_input: RemoteInvokeExecutionInfo) -> None: """ First runs all mappers for request object to get the final version of it. Then validates all the input boto parameters and invokes the BotoActionExecutor to get the result @@ -212,13 +249,11 @@ def execute(self, remote_invoke_input: RemoteInvokeExecutionInfo) -> RemoteInvok """ remote_invoke_input = self._map_input(remote_invoke_input) self._boto_action_executor.validate_action_parameters(remote_invoke_input.parameters) - remote_invoke_output = self._boto_action_executor.execute(remote_invoke_input) - - # call output mappers if the action is succeeded - if remote_invoke_output.is_succeeded(): - return self._map_output(remote_invoke_output) - - return remote_invoke_output + for remote_invoke_result in self._boto_action_executor.execute(remote_invoke_input): + if isinstance(remote_invoke_result, RemoteInvokeResponse): + self._response_consumer.consume(self._map_output(remote_invoke_result)) + if isinstance(remote_invoke_result, RemoteInvokeLogOutput): + self._log_consumer.consume(remote_invoke_result) def _map_input(self, remote_invoke_input: RemoteInvokeExecutionInfo) -> RemoteInvokeExecutionInfo: """ @@ -229,26 +264,28 @@ def _map_input(self, remote_invoke_input: RemoteInvokeExecutionInfo) -> RemoteIn remote_invoke_input : RemoteInvokeExecutionInfo Given remote invoke execution info which contains the request information - Returns : RemoteInvokeExecutionInfo + Returns ------- + RemoteInvokeExecutionInfo RemoteInvokeExecutionInfo which contains updated input payload """ for input_mapper in self._request_mappers: remote_invoke_input = input_mapper.map(remote_invoke_input) return remote_invoke_input - def _map_output(self, remote_invoke_output: RemoteInvokeExecutionInfo) -> RemoteInvokeExecutionInfo: + def _map_output(self, remote_invoke_output: RemoteInvokeResponse) -> RemoteInvokeResponse: """ Maps the given response through the response mapper list. Parameters ---------- - remote_invoke_output : RemoteInvokeExecutionInfo - Given remote invoke execution info which contains the response information + remote_invoke_output : RemoteInvokeResponse + Given remote invoke response which contains the payload itself - Returns : RemoteInvokeExecutionInfo + Returns ------- - RemoteInvokeExecutionInfo which contains updated response + RemoteInvokeResponse + Returns the mapped instance of RemoteInvokeResponse, after applying all configured mappers """ for output_mapper in self._response_mappers: remote_invoke_output = output_mapper.map(remote_invoke_output) diff --git a/tests/unit/commands/remote/invoke/test_cli.py b/tests/unit/commands/remote/invoke/test_cli.py index 6e97251b5a..97aecfc721 100644 --- a/tests/unit/commands/remote/invoke/test_cli.py +++ b/tests/unit/commands/remote/invoke/test_cli.py @@ -101,10 +101,6 @@ def test_remote_invoke_command( context_mock.run.assert_called_with(remote_invoke_input=given_remote_invoke_execution_info) - if log_output: - stderr_stream_writer_mock.write.assert_called() - stdout_stream_writer_mock.write.assert_called() - @parameterized.expand( [ (InvalideBotoResponseException,), @@ -114,14 +110,9 @@ def test_remote_invoke_command( ) @patch("samcli.commands.remote.remote_invoke_context.RemoteInvokeContext") def test_raise_user_exception_invoke_not_successfull(self, exeception_to_raise, mock_invoke_context): - context_mock = Mock() mock_invoke_context.return_value.__enter__.return_value = context_mock - - given_remote_invoke_result = Mock() - given_remote_invoke_result.is_succeeded.return_value = False - context_mock.run.return_value = given_remote_invoke_result - given_remote_invoke_result.exception = exeception_to_raise + context_mock.run.side_effect = exeception_to_raise with self.assertRaises(UserException): do_cli( diff --git a/tests/unit/commands/remote/test_remote_invoke_context.py b/tests/unit/commands/remote/test_remote_invoke_context.py index 0c04a01713..e01d9de5fb 100644 --- a/tests/unit/commands/remote/test_remote_invoke_context.py +++ b/tests/unit/commands/remote/test_remote_invoke_context.py @@ -129,14 +129,11 @@ def test_running_should_execute_remote_invoke_executor_instance( mocked_remote_invoke_executor_factory = Mock() patched_remote_invoke_executor_factory.return_value = mocked_remote_invoke_executor_factory mocked_remote_invoke_executor = Mock() - mocked_output = Mock() - mocked_remote_invoke_executor.execute.return_value = mocked_output mocked_remote_invoke_executor_factory.create_remote_invoke_executor.return_value = mocked_remote_invoke_executor given_input = Mock() with self._get_remote_invoke_context() as remote_invoke_context: - remote_invoke_result = remote_invoke_context.run(given_input) + remote_invoke_context.run(given_input) mocked_remote_invoke_executor_factory.create_remote_invoke_executor.assert_called_once() mocked_remote_invoke_executor.execute.assert_called_with(given_input) - self.assertEqual(remote_invoke_result, mocked_output) diff --git a/tests/unit/lib/remote_invoke/test_lambda_invoke_executors.py b/tests/unit/lib/remote_invoke/test_lambda_invoke_executors.py index 15ff272bac..dca00cafae 100644 --- a/tests/unit/lib/remote_invoke/test_lambda_invoke_executors.py +++ b/tests/unit/lib/remote_invoke/test_lambda_invoke_executors.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from typing import Any from unittest import TestCase -from unittest.mock import Mock, patch +from unittest.mock import Mock from parameterized import parameterized @@ -21,14 +21,12 @@ LambdaInvokeExecutor, LambdaInvokeWithResponseStreamExecutor, LambdaResponseConverter, - LambdaResponseOutputFormatter, LambdaStreamResponseConverter, - LambdaStreamResponseOutputFormatter, ParamValidationError, RemoteInvokeOutputFormat, _is_function_invoke_mode_response_stream, ) -from samcli.lib.remote_invoke.remote_invoke_executors import RemoteInvokeExecutionInfo +from samcli.lib.remote_invoke.remote_invoke_executors import RemoteInvokeExecutionInfo, RemoteInvokeResponse class CommonTestsLambdaInvokeExecutor: @@ -51,21 +49,24 @@ def test_execute_action_invalid_parameter_value_throws_client_error(self, error_ error = ClientError(error_response={"Error": {"Code": error_code}}, operation_name="invoke") self._get_boto3_method().side_effect = error with self.assertRaises(InvalidResourceBotoParameterException): - self.lambda_invoke_executor._execute_action(given_payload) + for _ in self.lambda_invoke_executor._execute_action(given_payload): + pass def test_execute_action_invalid_parameter_key_throws_parameter_validation_exception(self): given_payload = Mock() error = ParamValidationError(report="Invalid parameters") self._get_boto3_method().side_effect = error with self.assertRaises(InvalidResourceBotoParameterException): - self.lambda_invoke_executor._execute_action(given_payload) + for _ in self.lambda_invoke_executor._execute_action(given_payload): + pass def test_execute_action_throws_client_error_exception(self): - given_payload = Mock() + given_payload = "payload" error = ClientError(error_response={"Error": {"Code": "MockException"}}, operation_name="invoke") self._get_boto3_method().side_effect = error with self.assertRaises(ErrorBotoApiCallException): - self.lambda_invoke_executor._execute_action(given_payload) + for _ in self.lambda_invoke_executor._execute_action(given_payload): + pass @parameterized.expand( [ @@ -94,7 +95,9 @@ class TestLambdaInvokeExecutor(CommonTestsLambdaInvokeExecutor.AbstractLambdaInv def setUp(self) -> None: self.lambda_client = Mock() self.function_name = Mock() - self.lambda_invoke_executor = LambdaInvokeExecutor(self.lambda_client, self.function_name) + self.lambda_invoke_executor = LambdaInvokeExecutor( + self.lambda_client, self.function_name, RemoteInvokeOutputFormat.RAW + ) def test_execute_action(self): given_payload = Mock() @@ -103,7 +106,7 @@ def test_execute_action(self): result = self.lambda_invoke_executor._execute_action(given_payload) - self.assertEqual(result, given_result) + self.assertEqual(list(result), [RemoteInvokeResponse(given_result)]) self.lambda_client.invoke.assert_called_with( FunctionName=self.function_name, Payload=given_payload, InvocationType="RequestResponse", LogType="Tail" ) @@ -116,7 +119,9 @@ class TestLambdaInvokeWithResponseStreamExecutor(CommonTestsLambdaInvokeExecutor def setUp(self) -> None: self.lambda_client = Mock() self.function_name = Mock() - self.lambda_invoke_executor = LambdaInvokeWithResponseStreamExecutor(self.lambda_client, self.function_name) + self.lambda_invoke_executor = LambdaInvokeWithResponseStreamExecutor( + self.lambda_client, self.function_name, RemoteInvokeOutputFormat.RAW + ) def test_execute_action(self): given_payload = Mock() @@ -125,7 +130,7 @@ def test_execute_action(self): result = self.lambda_invoke_executor._execute_action(given_payload) - self.assertEqual(result, given_result) + self.assertEqual(list(result), [RemoteInvokeResponse(given_result)]) self.lambda_client.invoke_with_response_stream.assert_called_with( FunctionName=self.function_name, Payload=given_payload, InvocationType="RequestResponse", LogType="Tail" ) @@ -196,7 +201,9 @@ class TestLambdaStreamResponseConverter(TestCase): def setUp(self) -> None: self.lambda_stream_response_converter = LambdaStreamResponseConverter() - @parameterized.expand([({LOG_RESULT: base64.b64encode(b"log output")}, base64.b64encode(b"log output")), ({}, b"")]) + @parameterized.expand( + [({LOG_RESULT: base64.b64encode(b"log output")}, {LOG_RESULT: base64.b64encode(b"log output")}), ({}, {})] + ) def test_lambda_streaming_body_response_conversion(self, invoke_complete_response, mapped_log_response): output_format = RemoteInvokeOutputFormat.DEFAULT given_test_result = { @@ -207,20 +214,18 @@ def test_lambda_streaming_body_response_conversion(self, invoke_complete_respons {INVOKE_COMPLETE: invoke_complete_response}, ] } - remote_invoke_execution_info = RemoteInvokeExecutionInfo(None, None, {}, output_format) - remote_invoke_execution_info.response = given_test_result + remote_invoke_response = RemoteInvokeResponse(given_test_result) expected_result = { EVENT_STREAM: [ {PAYLOAD_CHUNK: {PAYLOAD: "stream1"}}, {PAYLOAD_CHUNK: {PAYLOAD: "stream2"}}, {PAYLOAD_CHUNK: {PAYLOAD: "stream3"}}, - {INVOKE_COMPLETE: {LOG_RESULT: mapped_log_response}}, + {INVOKE_COMPLETE: {**mapped_log_response}}, ] } - result = self.lambda_stream_response_converter.map(remote_invoke_execution_info) - + result = self.lambda_stream_response_converter.map(remote_invoke_response) self.assertEqual(result.response, expected_result) def test_lambda_streaming_body_invalid_response_exception(self): @@ -232,83 +237,6 @@ def test_lambda_streaming_body_invalid_response_exception(self): self.lambda_stream_response_converter.map(remote_invoke_execution_info) -class TestLambdaResponseOutputFormatter(TestCase): - def setUp(self) -> None: - self.lambda_response_converter = LambdaResponseOutputFormatter() - - def test_lambda_response_original_boto_output_formatter(self): - given_response = {"Payload": {"StatusCode": 200, "message": "hello world"}} - output_format = RemoteInvokeOutputFormat.RAW - - remote_invoke_execution_info = RemoteInvokeExecutionInfo(None, None, {}, output_format) - remote_invoke_execution_info.response = given_response - result = self.lambda_response_converter.map(remote_invoke_execution_info) - - self.assertEqual(result.response, given_response) - - @patch("samcli.lib.remote_invoke.lambda_invoke_executors.base64") - def test_lambda_response_default_output_formatter(self, base64_mock): - decoded_log_str = "decoded log string" - log_str_mock = Mock() - base64_mock.b64decode().decode.return_value = decoded_log_str - given_response = {"Payload": {"StatusCode": 200, "message": "hello world"}, "LogResult": log_str_mock} - output_format = RemoteInvokeOutputFormat.DEFAULT - - remote_invoke_execution_info = RemoteInvokeExecutionInfo(None, None, {}, output_format) - remote_invoke_execution_info.response = given_response - - expected_result = {"StatusCode": 200, "message": "hello world"} - result = self.lambda_response_converter.map(remote_invoke_execution_info) - - self.assertEqual(result.response, expected_result) - self.assertEqual(result.log_output, decoded_log_str) - - @parameterized.expand( - [ - ({"InvocationType": "DryRun", "Qualifier": "TestQualifier"},), - ({"InvocationType": "Event", "LogType": None},), - ] - ) - def test_non_default_invocation_type_output_formatter(self, parameters): - given_response = {"StatusCode": 200, "Payload": {"StatusCode": 200, "message": "hello world"}} - output_format = RemoteInvokeOutputFormat.DEFAULT - - remote_invoke_execution_info = RemoteInvokeExecutionInfo(None, None, parameters, output_format) - remote_invoke_execution_info.response = given_response - - expected_result = {"StatusCode": 200} - result = self.lambda_response_converter.map(remote_invoke_execution_info) - - self.assertEqual(result.response, expected_result) - - -class TestLambdaStreamResponseOutputFormatter(TestCase): - def setUp(self) -> None: - self.lambda_response_converter = LambdaStreamResponseOutputFormatter() - - def test_none_event_stream(self): - remote_invoke_execution_info = RemoteInvokeExecutionInfo(None, None, {}, RemoteInvokeOutputFormat.DEFAULT) - remote_invoke_execution_info.response = {} - - mapped_response = self.lambda_response_converter.map(remote_invoke_execution_info) - self.assertEqual(mapped_response.response, "") - - def test_event_stream(self): - remote_invoke_execution_info = RemoteInvokeExecutionInfo(None, None, {}, RemoteInvokeOutputFormat.DEFAULT) - remote_invoke_execution_info.response = { - EVENT_STREAM: [ - {PAYLOAD_CHUNK: {PAYLOAD: "stream1"}}, - {PAYLOAD_CHUNK: {PAYLOAD: "stream2"}}, - {PAYLOAD_CHUNK: {PAYLOAD: "stream3"}}, - {INVOKE_COMPLETE: {LOG_RESULT: base64.b64encode(b"log output")}}, - ] - } - - mapped_response = self.lambda_response_converter.map(remote_invoke_execution_info) - self.assertEqual(mapped_response.response, "stream1stream2stream3") - self.assertEqual(mapped_response.log_output, "log output") - - class TestLambdaInvokeExecutorUtilities(TestCase): @parameterized.expand( [ diff --git a/tests/unit/lib/remote_invoke/test_remote_invoke_executor_factory.py b/tests/unit/lib/remote_invoke/test_remote_invoke_executor_factory.py index 3a1f938e19..bbb8c1bac9 100644 --- a/tests/unit/lib/remote_invoke/test_remote_invoke_executor_factory.py +++ b/tests/unit/lib/remote_invoke/test_remote_invoke_executor_factory.py @@ -1,3 +1,4 @@ +import itertools from unittest import TestCase from unittest.mock import patch, Mock @@ -6,6 +7,7 @@ from samcli.lib.remote_invoke.remote_invoke_executor_factory import ( RemoteInvokeExecutorFactory, ) +from samcli.lib.remote_invoke.remote_invoke_executors import RemoteInvokeOutputFormat class TestRemoteInvokeExecutorFactory(TestCase): @@ -24,38 +26,48 @@ def test_create_remote_invoke_executor(self, patched_executor_mapping): given_executor_creator_method.return_value = given_executor given_cfn_resource_summary = Mock() - executor = self.remote_invoke_executor_factory.create_remote_invoke_executor(given_cfn_resource_summary) + given_output_format = Mock() + given_response_consumer = Mock() + given_log_consumer = Mock() + executor = self.remote_invoke_executor_factory.create_remote_invoke_executor( + given_cfn_resource_summary, given_output_format, given_response_consumer, given_log_consumer + ) patched_executor_mapping.get.assert_called_with(given_cfn_resource_summary.resource_type) given_executor_creator_method.assert_called_with( - self.remote_invoke_executor_factory, given_cfn_resource_summary + self.remote_invoke_executor_factory, + given_cfn_resource_summary, + given_output_format, + given_response_consumer, + given_log_consumer, ) self.assertEqual(executor, given_executor) def test_failed_create_test_executor(self): given_cfn_resource_summary = Mock() - executor = self.remote_invoke_executor_factory.create_remote_invoke_executor(given_cfn_resource_summary) + executor = self.remote_invoke_executor_factory.create_remote_invoke_executor( + given_cfn_resource_summary, Mock(), Mock(), Mock() + ) self.assertIsNone(executor) - @parameterized.expand([(True,), (False,)]) + @parameterized.expand( + itertools.product([True, False], [RemoteInvokeOutputFormat.RAW, RemoteInvokeOutputFormat.DEFAULT]) + ) @patch("samcli.lib.remote_invoke.remote_invoke_executor_factory.LambdaInvokeExecutor") @patch("samcli.lib.remote_invoke.remote_invoke_executor_factory.LambdaInvokeWithResponseStreamExecutor") @patch("samcli.lib.remote_invoke.remote_invoke_executor_factory.DefaultConvertToJSON") @patch("samcli.lib.remote_invoke.remote_invoke_executor_factory.LambdaResponseConverter") @patch("samcli.lib.remote_invoke.remote_invoke_executor_factory.LambdaStreamResponseConverter") - @patch("samcli.lib.remote_invoke.remote_invoke_executor_factory.LambdaResponseOutputFormatter") - @patch("samcli.lib.remote_invoke.remote_invoke_executor_factory.LambdaStreamResponseOutputFormatter") @patch("samcli.lib.remote_invoke.remote_invoke_executor_factory.ResponseObjectToJsonStringMapper") @patch("samcli.lib.remote_invoke.remote_invoke_executor_factory.RemoteInvokeExecutor") @patch("samcli.lib.remote_invoke.remote_invoke_executor_factory._is_function_invoke_mode_response_stream") def test_create_lambda_test_executor( self, is_function_invoke_mode_response_stream, + remote_invoke_output_format, patched_is_function_invoke_mode_response_stream, patched_remote_invoke_executor, patched_object_to_json_converter, - patched_stream_response_output_formatter, - patched_response_output_formatter, patched_stream_response_converter, patched_response_converter, patched_convert_to_default_json, @@ -72,36 +84,51 @@ def test_create_lambda_test_executor( given_remote_invoke_executor = Mock() patched_remote_invoke_executor.return_value = given_remote_invoke_executor - lambda_executor = self.remote_invoke_executor_factory._create_lambda_boto_executor(given_cfn_resource_summary) + given_response_consumer = Mock() + given_log_consumer = Mock() + lambda_executor = self.remote_invoke_executor_factory._create_lambda_boto_executor( + given_cfn_resource_summary, remote_invoke_output_format, given_response_consumer, given_log_consumer + ) self.assertEqual(lambda_executor, given_remote_invoke_executor) self.boto_client_provider_mock.assert_called_with("lambda") patched_convert_to_default_json.assert_called_once() - patched_object_to_json_converter.assert_called_once() if is_function_invoke_mode_response_stream: - patched_stream_response_output_formatter.assert_called_once() - patched_stream_response_converter.assert_called_once() - patched_lambda_invoke_with_response_stream_executor.assert_called_once() - patched_remote_invoke_executor.assert_called_with( - request_mappers=[patched_convert_to_default_json()], - response_mappers=[ + expected_mappers = [] + if remote_invoke_output_format == RemoteInvokeOutputFormat.RAW: + patched_object_to_json_converter.assert_called_once() + patched_stream_response_converter.assert_called_once() + patched_lambda_invoke_with_response_stream_executor.assert_called_with( + given_lambda_client, given_physical_resource_id, remote_invoke_output_format + ) + expected_mappers = [ patched_stream_response_converter(), - patched_stream_response_output_formatter(), patched_object_to_json_converter(), - ], + ] + patched_remote_invoke_executor.assert_called_with( + request_mappers=[patched_convert_to_default_json()], + response_mappers=expected_mappers, boto_action_executor=patched_lambda_invoke_with_response_stream_executor(), + response_consumer=given_response_consumer, + log_consumer=given_log_consumer, ) else: - patched_response_output_formatter.assert_called_once() - patched_response_converter.assert_called_once() - patched_lambda_invoke_executor.assert_called_with(given_lambda_client, given_physical_resource_id) - patched_remote_invoke_executor.assert_called_with( - request_mappers=[patched_convert_to_default_json()], - response_mappers=[ + expected_mappers = [] + if remote_invoke_output_format == RemoteInvokeOutputFormat.RAW: + patched_object_to_json_converter.assert_called_once() + patched_response_converter.assert_called_once() + patched_lambda_invoke_executor.assert_called_with( + given_lambda_client, given_physical_resource_id, remote_invoke_output_format + ) + expected_mappers = [ patched_response_converter(), - patched_response_output_formatter(), patched_object_to_json_converter(), - ], + ] + patched_remote_invoke_executor.assert_called_with( + request_mappers=[patched_convert_to_default_json()], + response_mappers=expected_mappers, boto_action_executor=patched_lambda_invoke_executor(), + response_consumer=given_response_consumer, + log_consumer=given_log_consumer, ) diff --git a/tests/unit/lib/remote_invoke/test_remote_invoke_executors.py b/tests/unit/lib/remote_invoke/test_remote_invoke_executors.py index bb8cfebb2e..8f3ce96e46 100644 --- a/tests/unit/lib/remote_invoke/test_remote_invoke_executors.py +++ b/tests/unit/lib/remote_invoke/test_remote_invoke_executors.py @@ -11,6 +11,7 @@ ResponseObjectToJsonStringMapper, RemoteInvokeRequestResponseMapper, RemoteInvokeOutputFormat, + RemoteInvokeResponse, ) @@ -89,8 +90,6 @@ def test_execute_with_payload(self): patched_execute_action.assert_called_with(given_payload) patched_execute_action_file.assert_not_called() - self.assertEqual(given_result, result.response) - def test_execute_with_payload_file(self): given_payload_file = Mock() given_parameters = {"ExampleParameter": "ExampleValue"} @@ -108,8 +107,6 @@ def test_execute_with_payload_file(self): patched_execute_action_file.assert_called_with(given_payload_file) patched_execute_action.assert_not_called() - self.assertEqual(given_result, result.response) - def test_execute_error(self): given_payload = Mock() given_parameters = {"ExampleParameter": "ExampleValue"} @@ -120,11 +117,9 @@ def test_execute_error(self): given_exception = ValueError() patched_execute_action.side_effect = given_exception - result = self.boto_action_executor.execute(test_execution_info) - - patched_execute_action.assert_called_with(given_payload) - - self.assertEqual(given_exception, result.exception) + with self.assertRaises(ValueError): + result = self.boto_action_executor.execute(test_execution_info) + patched_execute_action.assert_called_with(given_payload) class TestRemoteInvokeExecutor(TestCase): @@ -142,20 +137,20 @@ def setUp(self) -> None: ] self.test_executor = RemoteInvokeExecutor( - self.mock_request_mappers, self.mock_response_mappers, self.mock_boto_action_executor + self.mock_request_mappers, self.mock_response_mappers, self.mock_boto_action_executor, Mock(), Mock() ) def test_execution(self): given_payload = Mock() given_parameters = {"ExampleParameter": "ExampleValue"} - given_output_format = "original-boto-response" + given_output_format = RemoteInvokeOutputFormat.RAW test_execution_info = RemoteInvokeExecutionInfo(given_payload, None, given_parameters, given_output_format) validate_action_parameters_function = Mock() self.mock_boto_action_executor.validate_action_parameters = validate_action_parameters_function + self.mock_boto_action_executor.execute.return_value = [RemoteInvokeResponse(Mock())] - result = self.test_executor.execute(remote_invoke_input=test_execution_info) + self.test_executor.execute(remote_invoke_input=test_execution_info) - self.assertIsNotNone(result) validate_action_parameters_function.assert_called_once() for request_mapper in self.mock_request_mappers: @@ -167,7 +162,7 @@ def test_execution(self): def test_execution_failure(self): given_payload = Mock() given_parameters = {"ExampleParameter": "ExampleValue"} - given_output_format = "original-boto-response" + given_output_format = RemoteInvokeOutputFormat.RAW test_execution_info = RemoteInvokeExecutionInfo(given_payload, None, given_parameters, given_output_format) validate_action_parameters_function = Mock() self.mock_boto_action_executor.validate_action_parameters = validate_action_parameters_function @@ -176,11 +171,10 @@ def test_execution_failure(self): given_payload, None, given_parameters, given_output_format ) given_result_execution_info.exception = Mock() - self.mock_boto_action_executor.execute.return_value = given_result_execution_info + self.mock_boto_action_executor.execute.return_value = [given_result_execution_info] - result = self.test_executor.execute(test_execution_info) + self.test_executor.execute(test_execution_info) - self.assertIsNotNone(result) validate_action_parameters_function.assert_called_once() for request_mapper in self.mock_request_mappers: