Skip to content

Commit

Permalink
change: use regional endpoint when creating AWS STS client (#1026)
Browse files Browse the repository at this point in the history
  • Loading branch information
laurenyu committed Sep 6, 2019
1 parent ded3c8f commit 228a81d
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 10 deletions.
8 changes: 6 additions & 2 deletions src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
name_from_image,
secondary_training_status_changed,
secondary_training_status_message,
sts_regional_endpoint,
)
from sagemaker import exceptions

Expand Down Expand Up @@ -1377,10 +1378,13 @@ def expand_role(self, role):

def get_caller_identity_arn(self):
"""Returns the ARN user or role whose credentials are used to call the API.
Returns:
(str): The ARN user or role
str: The ARN user or role
"""
assumed_role = self.boto_session.client("sts").get_caller_identity()["Arn"]
assumed_role = self.boto_session.client(
"sts", endpoint_url=sts_regional_endpoint(self.boto_region_name)
).get_caller_identity()["Arn"]

if "AmazonSageMaker-ExecutionRole" in assumed_role:
role = re.sub(
Expand Down
20 changes: 19 additions & 1 deletion src/sagemaker/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# Copyright 2017-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
Expand Down Expand Up @@ -538,6 +538,24 @@ def get_ecr_image_uri_prefix(account, region):
return "{}.dkr.ecr.{}.{}".format(account, region, domain)


def sts_regional_endpoint(region):
"""Get the AWS STS endpoint specific for the given region.
We need this function because the AWS SDK does not yet honor
the ``region_name`` parameter when creating an AWS STS client.
For the list of regional endpoints, see
https://docs.aws.amazon.com/IAM/latest/UserGuide/id_credentials_temp_enable-regions.html#id_credentials_region-endpoints.
Args:
region (str): AWS region name
Returns:
str: AWS STS regional endpoint
"""
return "sts.{}.amazonaws.com".format(region)


class DeferredError(object):
"""Stores an exception and raises it at a later time if this object is
accessed in any way. Useful to allow soft-dependencies on imports, so that
Expand Down
25 changes: 18 additions & 7 deletions tests/unit/test_session.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# Copyright 2017-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
Expand Down Expand Up @@ -31,6 +31,7 @@
SAMPLE_PARAM_RANGES = [{"Name": "mini_batch_size", "MinValue": "10", "MaxValue": "100"}]

REGION = "us-west-2"
STS_ENDPOINT = "sts.us-west-2.amazonaws.com"


@pytest.fixture()
Expand Down Expand Up @@ -88,7 +89,9 @@ def test_get_execution_role_throws_exception_if_arn_is_not_role_with_role_in_nam
def test_get_caller_identity_arn_from_an_user(boto_session):
sess = Session(boto_session)
arn = "arn:aws:iam::369233609183:user/mia"
sess.boto_session.client("sts").get_caller_identity.return_value = {"Arn": arn}
sess.boto_session.client("sts", endpoint_url=STS_ENDPOINT).get_caller_identity.return_value = {
"Arn": arn
}
sess.boto_session.client("iam").get_role.return_value = {"Role": {"Arn": arn}}

actual = sess.get_caller_identity_arn()
Expand All @@ -98,7 +101,9 @@ def test_get_caller_identity_arn_from_an_user(boto_session):
def test_get_caller_identity_arn_from_an_user_without_permissions(boto_session):
sess = Session(boto_session)
arn = "arn:aws:iam::369233609183:user/mia"
sess.boto_session.client("sts").get_caller_identity.return_value = {"Arn": arn}
sess.boto_session.client("sts", endpoint_url=STS_ENDPOINT).get_caller_identity.return_value = {
"Arn": arn
}
sess.boto_session.client("iam").get_role.side_effect = ClientError({}, {})

with patch("logging.Logger.warning") as mock_logger:
Expand All @@ -112,7 +117,9 @@ def test_get_caller_identity_arn_from_a_role(boto_session):
arn = (
"arn:aws:sts::369233609183:assumed-role/SageMakerRole/6d009ef3-5306-49d5-8efc-78db644d8122"
)
sess.boto_session.client("sts").get_caller_identity.return_value = {"Arn": arn}
sess.boto_session.client("sts", endpoint_url=STS_ENDPOINT).get_caller_identity.return_value = {
"Arn": arn
}

expected_role = "arn:aws:iam::369233609183:role/SageMakerRole"
sess.boto_session.client("iam").get_role.return_value = {"Role": {"Arn": expected_role}}
Expand All @@ -124,7 +131,9 @@ def test_get_caller_identity_arn_from_a_role(boto_session):
def test_get_caller_identity_arn_from_a_execution_role(boto_session):
sess = Session(boto_session)
arn = "arn:aws:sts::369233609183:assumed-role/AmazonSageMaker-ExecutionRole-20171129T072388/SageMaker"
sess.boto_session.client("sts").get_caller_identity.return_value = {"Arn": arn}
sess.boto_session.client("sts", endpoint_url=STS_ENDPOINT).get_caller_identity.return_value = {
"Arn": arn
}
sess.boto_session.client("iam").get_role.return_value = {"Role": {"Arn": arn}}

actual = sess.get_caller_identity_arn()
Expand All @@ -138,7 +147,7 @@ def test_get_caller_identity_arn_from_role_with_path(boto_session):
sess = Session(boto_session)
arn_prefix = "arn:aws:iam::369233609183:role"
role_name = "name"
sess.boto_session.client("sts").get_caller_identity.return_value = {
sess.boto_session.client("sts", endpoint_url=STS_ENDPOINT).get_caller_identity.return_value = {
"Arn": "/".join([arn_prefix, role_name])
}

Expand Down Expand Up @@ -344,7 +353,9 @@ def test_s3_input_all_arguments():
@pytest.fixture()
def sagemaker_session():
boto_mock = Mock(name="boto_session")
boto_mock.client("sts").get_caller_identity.return_value = {"Account": "123"}
boto_mock.client("sts", endpoint_url=STS_ENDPOINT).get_caller_identity.return_value = {
"Account": "123"
}
ims = sagemaker.Session(boto_session=boto_mock, sagemaker_client=Mock())
ims.expand_role = Mock(return_value=EXPANDED_ROLE)
return ims
Expand Down
5 changes: 5 additions & 0 deletions tests/unit/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,3 +560,8 @@ def walk():

result = set(walk())
return result if result else {}


def test_sts_regional_endpoint():
endpoint = sagemaker.utils.sts_regional_endpoint("us-west-2")
assert endpoint == "sts.us-west-2.amazonaws.com"

0 comments on commit 228a81d

Please sign in to comment.