diff --git a/samcli/commands/local/cli_common/invoke_context.py b/samcli/commands/local/cli_common/invoke_context.py index 10e9fdb307..0e56e27846 100644 --- a/samcli/commands/local/cli_common/invoke_context.py +++ b/samcli/commands/local/cli_common/invoke_context.py @@ -10,6 +10,8 @@ from pathlib import Path from typing import Any, Dict, List, Optional, TextIO, Tuple, Type, cast +from botocore.exceptions import ClientError, NoCredentialsError, TokenRetrievalError + from samcli.commands._utils.template import TemplateFailedParsingException, TemplateNotFoundException from samcli.commands.exceptions import ContainersInitializationException from samcli.commands.local.cli_common.user_exceptions import DebugContextException, InvokeContextException @@ -20,6 +22,7 @@ from samcli.lib.providers.sam_stack_provider import SamLocalStackProvider from samcli.lib.utils import osutils from samcli.lib.utils.async_utils import AsyncContext +from samcli.lib.utils.boto_utils import get_boto_client_provider_with_config from samcli.lib.utils.packagetype import ZIP from samcli.lib.utils.stream_writer import StreamWriter from samcli.local.docker.exceptions import PortAlreadyInUse @@ -178,6 +181,7 @@ def __init__( self._aws_region = aws_region self._aws_profile = aws_profile self._shutdown = shutdown + self._add_account_id_to_global() self._container_host = container_host self._container_host_interface = container_host_interface @@ -345,6 +349,25 @@ def _clean_running_containers_and_related_resources(self) -> None: cast(WarmLambdaRuntime, self.lambda_runtime).clean_running_containers_and_related_resources() cast(RefreshableSamFunctionProvider, self._function_provider).stop_observer() + def _add_account_id_to_global(self) -> None: + """ + Attempts to get the Account ID from the current session + If there is no current session, the standard parameter override for + AWS::AccountId is used + """ + client_provider = get_boto_client_provider_with_config(region=self._aws_region, profile=self._aws_profile) + + sts = client_provider("sts") + + try: + account_id = sts.get_caller_identity().get("Account") + if account_id: + if self._global_parameter_overrides is None: + self._global_parameter_overrides = {} + self._global_parameter_overrides["AWS::AccountId"] = account_id + except (NoCredentialsError, TokenRetrievalError, ClientError): + LOG.warning("No current session found, using default AWS::AccountId") + @property def function_identifier(self) -> str: """ diff --git a/tests/unit/commands/local/cli_common/test_invoke_context.py b/tests/unit/commands/local/cli_common/test_invoke_context.py index ccdfebaae7..d406df5060 100644 --- a/tests/unit/commands/local/cli_common/test_invoke_context.py +++ b/tests/unit/commands/local/cli_common/test_invoke_context.py @@ -28,7 +28,10 @@ class TestInvokeContext__enter__(TestCase): @patch("samcli.commands.local.cli_common.invoke_context.ContainerManager") @patch("samcli.commands.local.cli_common.invoke_context.SamFunctionProvider") - def test_must_read_from_necessary_files(self, SamFunctionProviderMock, ContainerManagerMock): + @patch("samcli.commands.local.cli_common.invoke_context.InvokeContext._add_account_id_to_global") + def test_must_read_from_necessary_files( + self, _add_account_id_to_global_mock, SamFunctionProviderMock, ContainerManagerMock + ): function_provider = Mock() function_provider.get_all.return_value = [ Mock( @@ -116,8 +119,9 @@ def test_must_read_from_necessary_files(self, SamFunctionProviderMock, Container @patch("samcli.commands.local.cli_common.invoke_context.ContainerManager") @patch("samcli.commands.local.cli_common.invoke_context.RefreshableSamFunctionProvider") + @patch("samcli.commands.local.cli_common.invoke_context.InvokeContext._add_account_id_to_global") def test_must_initialize_all_containers_if_warm_containers_is_enabled( - self, RefreshableSamFunctionProviderMock, ContainerManagerMock + self, _add_account_id_to_global_mock, RefreshableSamFunctionProviderMock, ContainerManagerMock ): function_provider = Mock() function = Mock() @@ -205,8 +209,9 @@ def test_must_initialize_all_containers_if_warm_containers_is_enabled( @patch("samcli.commands.local.cli_common.invoke_context.ContainerManager") @patch("samcli.commands.local.cli_common.invoke_context.RefreshableSamFunctionProvider") + @patch("samcli.commands.local.cli_common.invoke_context.InvokeContext._add_account_id_to_global") def test_must_set_debug_function_if_warm_containers_enabled_no_debug_function_provided_and_template_contains_one_function( - self, RefreshableSamFunctionProviderMock, ContainerManagerMock + self, _add_account_id_to_global_mock, RefreshableSamFunctionProviderMock, ContainerManagerMock ): function_provider = Mock() function = Mock( @@ -300,8 +305,9 @@ def test_must_set_debug_function_if_warm_containers_enabled_no_debug_function_pr @patch("samcli.commands.local.cli_common.invoke_context.ContainerManager") @patch("samcli.commands.local.cli_common.invoke_context.RefreshableSamFunctionProvider") + @patch("samcli.commands.local.cli_common.invoke_context.InvokeContext._add_account_id_to_global") def test_no_container_will_be_initialized_if_lazy_containers_is_enabled( - self, RefreshableSamFunctionProviderMock, ContainerManagerMock + self, _add_account_id_to_global_mock, RefreshableSamFunctionProviderMock, ContainerManagerMock ): function_provider = Mock() function_provider.get_all.return_value = [ @@ -504,7 +510,8 @@ class TestInvokeContextAsContextManager(TestCase): @patch.object(InvokeContext, "__enter__") @patch.object(InvokeContext, "__exit__") - def test_must_work_in_with_statement(self, ExitMock, EnterMock): + @patch("samcli.commands.local.cli_common.invoke_context.InvokeContext._add_account_id_to_global") + def test_must_work_in_with_statement(self, _add_account_id_to_global_mock, ExitMock, EnterMock): context_obj = Mock() EnterMock.return_value = context_obj @@ -562,8 +569,10 @@ class TestInvokeContext_local_lambda_runner(TestCase): @patch("samcli.commands.local.cli_common.invoke_context.LambdaRuntime") @patch("samcli.commands.local.cli_common.invoke_context.LocalLambdaRunner") @patch("samcli.commands.local.cli_common.invoke_context.SamFunctionProvider") + @patch("samcli.commands.local.cli_common.invoke_context.InvokeContext._add_account_id_to_global") def test_must_create_runner( self, + _add_account_id_to_global_mock, SamFunctionProviderMock, LocalLambdaMock, LambdaRuntimeMock, @@ -645,8 +654,10 @@ def test_must_create_runner( @patch("samcli.commands.local.cli_common.invoke_context.WarmLambdaRuntime") @patch("samcli.commands.local.cli_common.invoke_context.LocalLambdaRunner") @patch("samcli.commands.local.cli_common.invoke_context.RefreshableSamFunctionProvider") + @patch("samcli.commands.local.cli_common.invoke_context.InvokeContext._add_account_id_to_global") def test_must_create_runner_using_warm_containers( self, + _add_account_id_to_global_mock, RefreshableSamFunctionProviderMock, LocalLambdaMock, WarmLambdaRuntimeMock, @@ -726,8 +737,10 @@ def test_must_create_runner_using_warm_containers( @patch("samcli.commands.local.cli_common.invoke_context.LambdaRuntime") @patch("samcli.commands.local.cli_common.invoke_context.LocalLambdaRunner") @patch("samcli.commands.local.cli_common.invoke_context.SamFunctionProvider") + @patch("samcli.commands.local.cli_common.invoke_context.InvokeContext._add_account_id_to_global") def test_must_create_runner_with_container_host_option( self, + _add_account_id_to_global_mock, SamFunctionProviderMock, LocalLambdaMock, LambdaRuntimeMock, @@ -812,8 +825,10 @@ def test_must_create_runner_with_container_host_option( @patch("samcli.commands.local.cli_common.invoke_context.LambdaRuntime") @patch("samcli.commands.local.cli_common.invoke_context.LocalLambdaRunner") @patch("samcli.commands.local.cli_common.invoke_context.SamFunctionProvider") + @patch("samcli.commands.local.cli_common.invoke_context.InvokeContext._add_account_id_to_global") def test_must_create_runner_with_extra_hosts_option( self, + _add_account_id_to_global_mock, SamFunctionProviderMock, LocalLambdaMock, LambdaRuntimeMock, @@ -901,8 +916,10 @@ def test_must_create_runner_with_extra_hosts_option( @patch("samcli.commands.local.cli_common.invoke_context.LambdaRuntime") @patch("samcli.commands.local.cli_common.invoke_context.LocalLambdaRunner") @patch("samcli.commands.local.cli_common.invoke_context.SamFunctionProvider") + @patch("samcli.commands.local.cli_common.invoke_context.InvokeContext._add_account_id_to_global") def test_must_create_runner_with_invoke_image_option( self, + _add_account_id_to_global_mock, SamFunctionProviderMock, LocalLambdaMock, LambdaRuntimeMock, @@ -1353,10 +1370,27 @@ def test_debugger_path_resolves(self, pathlib_mock, debug_context_mock): class TestInvokeContext_get_stacks(TestCase): @patch("samcli.commands.local.cli_common.invoke_context.SamLocalStackProvider.get_stacks") - def test_must_pass_custom_region(self, get_stacks_mock): + @patch("samcli.commands.local.cli_common.invoke_context.InvokeContext._add_account_id_to_global") + def test_must_pass_custom_region(self, add_account_id_to_global_mock, get_stacks_mock): get_stacks_mock.return_value = [Mock(), []] invoke_context = InvokeContext("template_file", aws_region="my-custom-region") invoke_context._get_stacks() get_stacks_mock.assert_called_with( "template_file", parameter_overrides=None, global_parameter_overrides={"AWS::Region": "my-custom-region"} ) + + +class TestInvokeContext_add_account_id_to_global(TestCase): + def test_must_work_with_no_token(self): + invoke_context = InvokeContext("template_file") + invoke_context._add_account_id_to_global() + self.assertIsNone(invoke_context._global_parameter_overrides) + + @patch("samcli.commands.local.cli_common.invoke_context.get_boto_client_provider_with_config") + def test_must_work_with_token(self, get_boto_client_provider_with_config_mock): + get_boto_client_provider_with_config_mock.return_value.return_value.get_caller_identity.return_value.get.return_value = ( + "210987654321" + ) + invoke_context = InvokeContext("template_file") + invoke_context._add_account_id_to_global() + self.assertEqual(invoke_context._global_parameter_overrides.get("AWS::AccountId"), "210987654321")