Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow sam local invoke to retrieve account id from current logged in session #7013

Merged
merged 27 commits into from
Jun 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
61a5c5f
Allow AWS sam local invoke to retrieve account id from current logged…
kevin-james-sp May 2, 2024
9605545
Moved account id code to separate function
kevin-james-sp May 3, 2024
c7732ac
Merge branch 'develop' into develop
mildaniel May 6, 2024
8ef323f
Update samcli/commands/local/cli_common/invoke_context.py
defenderkev May 7, 2024
6a39e1d
Update samcli/commands/local/cli_common/invoke_context.py
defenderkev May 7, 2024
fb33fbe
Requested changes
kevin-james-sp May 7, 2024
77543bf
unit tests
kevin-james-sp May 7, 2024
fe8ace6
Update tests/unit/commands/local/cli_common/test_invoke_context.py
defenderkev May 7, 2024
f2ee4a2
put docstring in
kevin-james-sp May 7, 2024
8da5919
requested assertion changes
kevin-james-sp May 7, 2024
0c6d025
Fix changes requested by @hawflau
defenderkev May 20, 2024
f412c84
Merge branch 'develop' into develop
defenderkev May 20, 2024
290c556
missed removing a debugging statement
defenderkev May 24, 2024
b3c5831
add cmd line params to client_provider init
defenderkev May 24, 2024
f046969
Merge branch 'develop' into develop
defenderkev May 24, 2024
c2543ce
import get boto client method directly
defenderkev May 28, 2024
3ec9847
Changed some existing tests, as the new code now means any profile sp…
defenderkev May 28, 2024
6190ad2
Merge branch 'develop' into develop
mildaniel May 28, 2024
40ee4e2
Added return type annotation
defenderkev May 28, 2024
29f539c
Merge branch 'develop' of github-as-defenderkev:defenderkev/aws-sam-c…
defenderkev May 28, 2024
e7061b3
catch another Exception type, as `make pr` was failing when no creden…
defenderkev May 28, 2024
5e7c72e
Merge branch 'develop' into develop
defenderkev May 30, 2024
dafa4c8
Merge branch 'develop' into develop
jysheng123 May 30, 2024
90bca1a
Add ClientError catch and update some of the tests
defenderkev Jun 3, 2024
1740627
Fixed the last tests - mocked out the new account_id code as it's not…
defenderkev Jun 5, 2024
35b0a22
Merge branch 'develop' into develop
hnnasit Jun 10, 2024
74d8852
Merge branch 'develop' into develop
lucashuy Jun 11, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions samcli/commands/local/cli_common/invoke_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from pathlib import Path
from typing import Any, Dict, List, Optional, TextIO, Tuple, Type, cast

from botocore.exceptions import ClientError, NoCredentialsError, TokenRetrievalError

from samcli.commands._utils.template import TemplateFailedParsingException, TemplateNotFoundException
from samcli.commands.exceptions import ContainersInitializationException
from samcli.commands.local.cli_common.user_exceptions import DebugContextException, InvokeContextException
Expand All @@ -20,6 +22,7 @@
from samcli.lib.providers.sam_stack_provider import SamLocalStackProvider
from samcli.lib.utils import osutils
from samcli.lib.utils.async_utils import AsyncContext
from samcli.lib.utils.boto_utils import get_boto_client_provider_with_config
from samcli.lib.utils.packagetype import ZIP
from samcli.lib.utils.stream_writer import StreamWriter
from samcli.local.docker.exceptions import PortAlreadyInUse
Expand Down Expand Up @@ -178,6 +181,7 @@ def __init__(
self._aws_region = aws_region
self._aws_profile = aws_profile
self._shutdown = shutdown
self._add_account_id_to_global()

self._container_host = container_host
self._container_host_interface = container_host_interface
Expand Down Expand Up @@ -345,6 +349,25 @@ def _clean_running_containers_and_related_resources(self) -> None:
cast(WarmLambdaRuntime, self.lambda_runtime).clean_running_containers_and_related_resources()
cast(RefreshableSamFunctionProvider, self._function_provider).stop_observer()

def _add_account_id_to_global(self) -> None:
"""
Attempts to get the Account ID from the current session
If there is no current session, the standard parameter override for
AWS::AccountId is used
"""
client_provider = get_boto_client_provider_with_config(region=self._aws_region, profile=self._aws_profile)

sts = client_provider("sts")

try:
account_id = sts.get_caller_identity().get("Account")
if account_id:
if self._global_parameter_overrides is None:
self._global_parameter_overrides = {}
self._global_parameter_overrides["AWS::AccountId"] = account_id
except (NoCredentialsError, TokenRetrievalError, ClientError):
LOG.warning("No current session found, using default AWS::AccountId")

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this function is ok to error, can you also add a generic exception handling right now? If theres no exception handling here, if an error occurs outside of the 2 defined errors, there is potential for this function to panic the execution

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I caatch these two exceptions because they can be the consequences of

  • Not having a ~/.aws/credentials or ~/.aws/config file
  • Valid configuration but token is expired
    If anything else causes an error, we don't know what's caused it and in my opinion the execution should panic because we don't know how to handle it.
    Thoughts? I mean, I'm happy to put in a generic except and continue but I'm not sure if that's what we should do here

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We might want to consider catching the ClientError exception here too. Using expired credentials, botocore returns:

botocore.exceptions.ClientError: An error occurred (ExpiredToken) when calling the GetCallerIdentity operation: The security token included in the request is expired

You had mentioned that other tests start failing once the ClientError was caught, maybe we could look at fixing those?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@lucashuy What I've done is add catching ClientError to the new method, and patching the new method for the existing tests so it's just a stub; they don't need a live connection so inserting a live account ID is irrelevant. All tests are passing now
I also put back in some definitions of aws_profile I had taken out of some of those tests. Now the only differences to upstream are the patching of _add_account_id_to_global and the extra tests for the new code.

@property
def function_identifier(self) -> str:
"""
Expand Down
46 changes: 40 additions & 6 deletions tests/unit/commands/local/cli_common/test_invoke_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,10 @@
class TestInvokeContext__enter__(TestCase):
@patch("samcli.commands.local.cli_common.invoke_context.ContainerManager")
@patch("samcli.commands.local.cli_common.invoke_context.SamFunctionProvider")
def test_must_read_from_necessary_files(self, SamFunctionProviderMock, ContainerManagerMock):
@patch("samcli.commands.local.cli_common.invoke_context.InvokeContext._add_account_id_to_global")
def test_must_read_from_necessary_files(
self, _add_account_id_to_global_mock, SamFunctionProviderMock, ContainerManagerMock
):
function_provider = Mock()
function_provider.get_all.return_value = [
Mock(
Expand Down Expand Up @@ -116,8 +119,9 @@ def test_must_read_from_necessary_files(self, SamFunctionProviderMock, Container

@patch("samcli.commands.local.cli_common.invoke_context.ContainerManager")
@patch("samcli.commands.local.cli_common.invoke_context.RefreshableSamFunctionProvider")
@patch("samcli.commands.local.cli_common.invoke_context.InvokeContext._add_account_id_to_global")
def test_must_initialize_all_containers_if_warm_containers_is_enabled(
self, RefreshableSamFunctionProviderMock, ContainerManagerMock
self, _add_account_id_to_global_mock, RefreshableSamFunctionProviderMock, ContainerManagerMock
):
function_provider = Mock()
function = Mock()
Expand Down Expand Up @@ -205,8 +209,9 @@ def test_must_initialize_all_containers_if_warm_containers_is_enabled(

@patch("samcli.commands.local.cli_common.invoke_context.ContainerManager")
@patch("samcli.commands.local.cli_common.invoke_context.RefreshableSamFunctionProvider")
@patch("samcli.commands.local.cli_common.invoke_context.InvokeContext._add_account_id_to_global")
def test_must_set_debug_function_if_warm_containers_enabled_no_debug_function_provided_and_template_contains_one_function(
self, RefreshableSamFunctionProviderMock, ContainerManagerMock
self, _add_account_id_to_global_mock, RefreshableSamFunctionProviderMock, ContainerManagerMock
):
function_provider = Mock()
function = Mock(
Expand Down Expand Up @@ -300,8 +305,9 @@ def test_must_set_debug_function_if_warm_containers_enabled_no_debug_function_pr

@patch("samcli.commands.local.cli_common.invoke_context.ContainerManager")
@patch("samcli.commands.local.cli_common.invoke_context.RefreshableSamFunctionProvider")
@patch("samcli.commands.local.cli_common.invoke_context.InvokeContext._add_account_id_to_global")
def test_no_container_will_be_initialized_if_lazy_containers_is_enabled(
self, RefreshableSamFunctionProviderMock, ContainerManagerMock
self, _add_account_id_to_global_mock, RefreshableSamFunctionProviderMock, ContainerManagerMock
):
function_provider = Mock()
function_provider.get_all.return_value = [
Expand Down Expand Up @@ -504,7 +510,8 @@ class TestInvokeContextAsContextManager(TestCase):

@patch.object(InvokeContext, "__enter__")
@patch.object(InvokeContext, "__exit__")
def test_must_work_in_with_statement(self, ExitMock, EnterMock):
@patch("samcli.commands.local.cli_common.invoke_context.InvokeContext._add_account_id_to_global")
def test_must_work_in_with_statement(self, _add_account_id_to_global_mock, ExitMock, EnterMock):
context_obj = Mock()
EnterMock.return_value = context_obj

Expand Down Expand Up @@ -562,8 +569,10 @@ class TestInvokeContext_local_lambda_runner(TestCase):
@patch("samcli.commands.local.cli_common.invoke_context.LambdaRuntime")
@patch("samcli.commands.local.cli_common.invoke_context.LocalLambdaRunner")
@patch("samcli.commands.local.cli_common.invoke_context.SamFunctionProvider")
@patch("samcli.commands.local.cli_common.invoke_context.InvokeContext._add_account_id_to_global")
def test_must_create_runner(
self,
_add_account_id_to_global_mock,
SamFunctionProviderMock,
LocalLambdaMock,
LambdaRuntimeMock,
Expand Down Expand Up @@ -645,8 +654,10 @@ def test_must_create_runner(
@patch("samcli.commands.local.cli_common.invoke_context.WarmLambdaRuntime")
@patch("samcli.commands.local.cli_common.invoke_context.LocalLambdaRunner")
@patch("samcli.commands.local.cli_common.invoke_context.RefreshableSamFunctionProvider")
@patch("samcli.commands.local.cli_common.invoke_context.InvokeContext._add_account_id_to_global")
def test_must_create_runner_using_warm_containers(
self,
_add_account_id_to_global_mock,
RefreshableSamFunctionProviderMock,
LocalLambdaMock,
WarmLambdaRuntimeMock,
Expand Down Expand Up @@ -726,8 +737,10 @@ def test_must_create_runner_using_warm_containers(
@patch("samcli.commands.local.cli_common.invoke_context.LambdaRuntime")
@patch("samcli.commands.local.cli_common.invoke_context.LocalLambdaRunner")
@patch("samcli.commands.local.cli_common.invoke_context.SamFunctionProvider")
@patch("samcli.commands.local.cli_common.invoke_context.InvokeContext._add_account_id_to_global")
def test_must_create_runner_with_container_host_option(
self,
_add_account_id_to_global_mock,
SamFunctionProviderMock,
LocalLambdaMock,
LambdaRuntimeMock,
Expand Down Expand Up @@ -812,8 +825,10 @@ def test_must_create_runner_with_container_host_option(
@patch("samcli.commands.local.cli_common.invoke_context.LambdaRuntime")
@patch("samcli.commands.local.cli_common.invoke_context.LocalLambdaRunner")
@patch("samcli.commands.local.cli_common.invoke_context.SamFunctionProvider")
@patch("samcli.commands.local.cli_common.invoke_context.InvokeContext._add_account_id_to_global")
def test_must_create_runner_with_extra_hosts_option(
self,
_add_account_id_to_global_mock,
SamFunctionProviderMock,
LocalLambdaMock,
LambdaRuntimeMock,
Expand Down Expand Up @@ -901,8 +916,10 @@ def test_must_create_runner_with_extra_hosts_option(
@patch("samcli.commands.local.cli_common.invoke_context.LambdaRuntime")
@patch("samcli.commands.local.cli_common.invoke_context.LocalLambdaRunner")
@patch("samcli.commands.local.cli_common.invoke_context.SamFunctionProvider")
@patch("samcli.commands.local.cli_common.invoke_context.InvokeContext._add_account_id_to_global")
def test_must_create_runner_with_invoke_image_option(
self,
_add_account_id_to_global_mock,
SamFunctionProviderMock,
LocalLambdaMock,
LambdaRuntimeMock,
Expand Down Expand Up @@ -1353,10 +1370,27 @@ def test_debugger_path_resolves(self, pathlib_mock, debug_context_mock):

class TestInvokeContext_get_stacks(TestCase):
@patch("samcli.commands.local.cli_common.invoke_context.SamLocalStackProvider.get_stacks")
def test_must_pass_custom_region(self, get_stacks_mock):
@patch("samcli.commands.local.cli_common.invoke_context.InvokeContext._add_account_id_to_global")
def test_must_pass_custom_region(self, add_account_id_to_global_mock, get_stacks_mock):
get_stacks_mock.return_value = [Mock(), []]
invoke_context = InvokeContext("template_file", aws_region="my-custom-region")
invoke_context._get_stacks()
get_stacks_mock.assert_called_with(
"template_file", parameter_overrides=None, global_parameter_overrides={"AWS::Region": "my-custom-region"}
)


class TestInvokeContext_add_account_id_to_global(TestCase):
def test_must_work_with_no_token(self):
invoke_context = InvokeContext("template_file")
invoke_context._add_account_id_to_global()
self.assertIsNone(invoke_context._global_parameter_overrides)

@patch("samcli.commands.local.cli_common.invoke_context.get_boto_client_provider_with_config")
def test_must_work_with_token(self, get_boto_client_provider_with_config_mock):
get_boto_client_provider_with_config_mock.return_value.return_value.get_caller_identity.return_value.get.return_value = (
"210987654321"
)
invoke_context = InvokeContext("template_file")
invoke_context._add_account_id_to_global()
self.assertEqual(invoke_context._global_parameter_overrides.get("AWS::AccountId"), "210987654321")
Loading