Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 27 additions & 15 deletions samtranslator/translator/arn_generator.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from functools import lru_cache

import boto3

from typing import Optional
Expand All @@ -7,6 +9,29 @@ class NoRegionFound(Exception):
pass


@lru_cache(maxsize=1) # Only need to cache one as once deployed, it is not gonna deal with another region.
def _get_region_from_session() -> str:
return boto3.session.Session().region_name


@lru_cache(maxsize=1) # Only need to cache one as once deployed, it is not gonna deal with another region.
def _region_to_partition(region: str) -> str:
# setting default partition to aws, this will be overwritten by checking the region below
partition = "aws"

region_string = region.lower()
if region_string.startswith("cn-"):
partition = "aws-cn"
elif region_string.startswith("us-iso-"):
partition = "aws-iso"
elif region_string.startswith("us-isob"):
partition = "aws-iso-b"
elif region_string.startswith("us-gov"):
partition = "aws-us-gov"

return partition


class ArnGenerator(object):
BOTO_SESSION_REGION_NAME = None

Expand Down Expand Up @@ -53,7 +78,7 @@ def get_partition_name(cls, region: Optional[str] = None) -> str:
# mechanism, starting from AWS_DEFAULT_REGION environment variable.

if ArnGenerator.BOTO_SESSION_REGION_NAME is None:
region = boto3.session.Session().region_name
region = _get_region_from_session()
else:
region = ArnGenerator.BOTO_SESSION_REGION_NAME # type: ignore[unreachable]

Expand All @@ -63,17 +88,4 @@ def get_partition_name(cls, region: Optional[str] = None) -> str:
if region is None:
raise NoRegionFound("AWS Region cannot be found")

# setting default partition to aws, this will be overwritten by checking the region below
partition = "aws"

region_string = region.lower()
if region_string.startswith("cn-"):
partition = "aws-cn"
elif region_string.startswith("us-iso-"):
partition = "aws-iso"
elif region_string.startswith("us-isob"):
partition = "aws-iso-b"
elif region_string.startswith("us-gov"):
partition = "aws-us-gov"

return partition
return _region_to_partition(region)
4 changes: 2 additions & 2 deletions tests/translator/test_arn_generator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from unittest import TestCase
from parameterized import parameterized
from unittest.mock import patch
from unittest.mock import Mock, patch

from samtranslator.translator.arn_generator import ArnGenerator, NoRegionFound

Expand All @@ -17,7 +17,7 @@ def test_get_partition_name(self, region, expected):

self.assertEqual(actual, expected)

@patch("boto3.session.Session.region_name", None)
@patch("samtranslator.translator.arn_generator._get_region_from_session", Mock(return_value=None))
def test_get_partition_name_raise_NoRegionFound(self):
with self.assertRaises(NoRegionFound):
ArnGenerator.get_partition_name(None)
Expand Down
7 changes: 5 additions & 2 deletions tests/translator/test_resource_level_attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,13 @@ class TestResourceLevelAttributes(AbstractTestTranslator):
"samtranslator.plugins.application.serverless_app_plugin.ServerlessAppPlugin._sar_service_call",
mock_sar_service_call,
)
@patch("botocore.client.ClientEndpointBridge._check_default_region", mock_get_region)
def test_transform_with_additional_resource_level_attributes(self, testcase, partition_with_region):
@patch("samtranslator.translator.arn_generator._get_region_from_session")
def test_transform_with_additional_resource_level_attributes(
self, testcase, partition_with_region, mock_get_region_from_session
):
partition = partition_with_region[0]
region = partition_with_region[1]
mock_get_region_from_session.return_value = region

# add resource level attributes to input resources
manifest = self._read_input(testcase)
Expand Down
20 changes: 12 additions & 8 deletions tests/translator/test_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,10 +268,11 @@ class TestTranslatorEndToEnd(AbstractTestTranslator):
"samtranslator.plugins.application.serverless_app_plugin.ServerlessAppPlugin._sar_service_call",
mock_sar_service_call,
)
@patch("botocore.client.ClientEndpointBridge._check_default_region", mock_get_region)
def test_transform_success(self, testcase, partition_with_region):
@patch("samtranslator.translator.arn_generator._get_region_from_session")
def test_transform_success(self, testcase, partition_with_region, mock_get_region_from_session):
partition = partition_with_region[0]
region = partition_with_region[1]
mock_get_region_from_session.return_value = region

manifest = self._read_input(testcase)
expected = self._read_expected_output(testcase, partition)
Expand Down Expand Up @@ -338,10 +339,11 @@ def test_transform_success(self, testcase, partition_with_region):
"samtranslator.plugins.application.serverless_app_plugin.ServerlessAppPlugin._sar_service_call",
mock_sar_service_call,
)
@patch("botocore.client.ClientEndpointBridge._check_default_region", mock_get_region)
def test_transform_success_openapi3(self, testcase, partition_with_region):
@patch("samtranslator.translator.arn_generator._get_region_from_session")
def test_transform_success_openapi3(self, testcase, partition_with_region, mock_get_region_from_session):
partition = partition_with_region[0]
region = partition_with_region[1]
mock_get_region_from_session.return_value = region

manifest = yaml_parse(open(os.path.join(INPUT_FOLDER, testcase + ".yaml"), "r"))
# To uncover unicode-related bugs, convert dict to JSON string and parse JSON back to dict
Expand Down Expand Up @@ -393,10 +395,11 @@ def test_transform_success_openapi3(self, testcase, partition_with_region):
"samtranslator.plugins.application.serverless_app_plugin.ServerlessAppPlugin._sar_service_call",
mock_sar_service_call,
)
@patch("botocore.client.ClientEndpointBridge._check_default_region", mock_get_region)
def test_transform_success_resource_policy(self, testcase, partition_with_region):
@patch("samtranslator.translator.arn_generator._get_region_from_session")
def test_transform_success_resource_policy(self, testcase, partition_with_region, mock_get_region_from_session):
partition = partition_with_region[0]
region = partition_with_region[1]
mock_get_region_from_session.return_value = region

manifest = yaml_parse(open(os.path.join(INPUT_FOLDER, testcase + ".yaml"), "r"))
# To uncover unicode-related bugs, convert dict to JSON string and parse JSON back to dict
Expand Down Expand Up @@ -441,8 +444,8 @@ def test_transform_success_resource_policy(self, testcase, partition_with_region
"samtranslator.plugins.application.serverless_app_plugin.ServerlessAppPlugin._sar_service_call",
mock_sar_service_call,
)
@patch("botocore.client.ClientEndpointBridge._check_default_region", mock_get_region)
def test_transform_success_no_side_effect(self, testcase, partition_with_region):
@patch("samtranslator.translator.arn_generator._get_region_from_session")
def test_transform_success_no_side_effect(self, testcase, partition_with_region, mock_get_region_from_session):
"""
Tests that the transform does not leak/leave data in shared caches/lists between executions
Performs the transform of the templates in a row without reinitialization
Expand All @@ -457,6 +460,7 @@ def test_transform_success_no_side_effect(self, testcase, partition_with_region)
"""
partition = partition_with_region[0]
region = partition_with_region[1]
mock_get_region_from_session.return_value = region

for template in testcase[1]:
print(template, partition, region)
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/translator/test_arn_generator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from unittest import TestCase

from unittest.mock import patch
from unittest.mock import Mock, patch
from parameterized import parameterized

from samtranslator.translator.arn_generator import ArnGenerator
Expand Down Expand Up @@ -31,5 +31,5 @@ def test_get_partition_name(self, region, expected_partition):
]
)
def test_get_partition_name_when_region_not_provided(self, region, expected_partition):
with patch("boto3.session.Session.region_name", region):
with patch("samtranslator.translator.arn_generator._get_region_from_session", Mock(return_value=region)):
self.assertEqual(expected_partition, ArnGenerator.get_partition_name())