Skip to content

Commit

Permalink
fix: make get_caller_identity_arn get role from DescribeNotebookInsta…
Browse files Browse the repository at this point in the history
…nce (#1033)

Add an initial attempt to get the role via DescribeNotebookInstance.
If that attempt fails fallback to the current heuristics-based behavior.
  • Loading branch information
Davidhw authored and laurenyu committed Sep 9, 2019
1 parent bae66a0 commit de676a1
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 1 deletion.
17 changes: 17 additions & 0 deletions src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@

LOGGER = logging.getLogger("sagemaker")

NOTEBOOK_METADATA_FILE = "/opt/ml/metadata/resource-metadata.json"

_STATUS_CODE_TABLE = {
"COMPLETED": "Completed",
"INPROGRESS": "InProgress",
Expand Down Expand Up @@ -1382,6 +1384,21 @@ def get_caller_identity_arn(self):
Returns:
str: The ARN user or role
"""
if os.path.exists(NOTEBOOK_METADATA_FILE):
with open(NOTEBOOK_METADATA_FILE, "rb") as f:
instance_name = json.loads(f.read())["ResourceName"]
try:
instance_desc = self.sagemaker_client.describe_notebook_instance(
NotebookInstanceName=instance_name
)
return instance_desc["RoleArn"]
except ClientError:
LOGGER.warning(
"Couldn't call 'describe_notebook_instance' to get the Role "
"ARN of the instance %s.",
instance_name,
)

assumed_role = self.boto_session.client(
"sts", endpoint_url=sts_regional_endpoint(self.boto_region_name)
).get_caller_identity()["Arn"]
Expand Down
69 changes: 68 additions & 1 deletion tests/unit/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import datetime
import io
import logging
import os

import pytest
import six
Expand All @@ -23,7 +24,12 @@

import sagemaker
from sagemaker import s3_input, Session, get_execution_role
from sagemaker.session import _tuning_job_status, _transform_job_status, _train_done
from sagemaker.session import (
_tuning_job_status,
_transform_job_status,
_train_done,
NOTEBOOK_METADATA_FILE,
)
from sagemaker.tuner import WarmStartConfig, WarmStartTypes

STATIC_HPs = {"feature_dim": "784"}
Expand All @@ -47,6 +53,18 @@ def boto_session():
return boto_session


def mock_exists(filepath_to_mock, exists_result):
unmocked_exists = os.path.exists

def side_effect(filepath):
if filepath == filepath_to_mock:
return exists_result
else:
return unmocked_exists(filepath)

return Mock(side_effect=side_effect)


def test_get_execution_role():
session = Mock()
session.get_caller_identity_arn.return_value = "arn:aws:iam::369233609183:role/SageMakerRole"
Expand Down Expand Up @@ -86,6 +104,51 @@ def test_get_execution_role_throws_exception_if_arn_is_not_role_with_role_in_nam
assert "ValueError: The current AWS identity is not a role" in str(error)


@patch("six.moves.builtins.open", mock_open(read_data='{"ResourceName": "SageMakerInstance"}'))
@patch("os.path.exists", side_effect=mock_exists(NOTEBOOK_METADATA_FILE, True))
def test_get_caller_identity_arn_from_describe_notebook_instance(boto_session):
sess = Session(boto_session)
expected_role = "arn:aws:iam::369233609183:role/service-role/SageMakerRole-20171129T072388"
sess.sagemaker_client.describe_notebook_instance.return_value = {"RoleArn": expected_role}

actual = sess.get_caller_identity_arn()

assert actual == expected_role
sess.sagemaker_client.describe_notebook_instance.assert_called_once_with(
NotebookInstanceName="SageMakerInstance"
)


@patch("six.moves.builtins.open", mock_open(read_data='{"ResourceName": "SageMakerInstance"}'))
@patch("os.path.exists", side_effect=mock_exists(NOTEBOOK_METADATA_FILE, True))
def test_get_caller_identity_arn_from_a_role_after_describe_notebook_exception(boto_session):
sess = Session(boto_session)
exception = ClientError(
{"Error": {"Code": "ValidationException", "Message": "RecordNotFound"}}, "Operation"
)
sess.sagemaker_client.describe_notebook_instance.side_effect = exception

arn = (
"arn:aws:sts::369233609183:assumed-role/SageMakerRole/6d009ef3-5306-49d5-8efc-78db644d8122"
)
sess.boto_session.client("sts", endpoint_url=STS_ENDPOINT).get_caller_identity.return_value = {
"Arn": arn
}

expected_role = "arn:aws:iam::369233609183:role/SageMakerRole"
sess.boto_session.client("iam").get_role.return_value = {"Role": {"Arn": expected_role}}

with patch("logging.Logger.warning") as mock_logger:
actual = sess.get_caller_identity_arn()
mock_logger.assert_called_once()

sess.sagemaker_client.describe_notebook_instance.assert_called_once_with(
NotebookInstanceName="SageMakerInstance"
)
assert actual == expected_role


@patch("os.path.exists", side_effect=mock_exists(NOTEBOOK_METADATA_FILE, False))
def test_get_caller_identity_arn_from_an_user(boto_session):
sess = Session(boto_session)
arn = "arn:aws:iam::369233609183:user/mia"
Expand All @@ -98,6 +161,7 @@ def test_get_caller_identity_arn_from_an_user(boto_session):
assert actual == "arn:aws:iam::369233609183:user/mia"


@patch("os.path.exists", side_effect=mock_exists(NOTEBOOK_METADATA_FILE, False))
def test_get_caller_identity_arn_from_an_user_without_permissions(boto_session):
sess = Session(boto_session)
arn = "arn:aws:iam::369233609183:user/mia"
Expand All @@ -112,6 +176,7 @@ def test_get_caller_identity_arn_from_an_user_without_permissions(boto_session):
mock_logger.assert_called_once()


@patch("os.path.exists", side_effect=mock_exists(NOTEBOOK_METADATA_FILE, False))
def test_get_caller_identity_arn_from_a_role(boto_session):
sess = Session(boto_session)
arn = (
Expand All @@ -128,6 +193,7 @@ def test_get_caller_identity_arn_from_a_role(boto_session):
assert actual == expected_role


@patch("os.path.exists", side_effect=mock_exists(NOTEBOOK_METADATA_FILE, False))
def test_get_caller_identity_arn_from_a_execution_role(boto_session):
sess = Session(boto_session)
arn = "arn:aws:sts::369233609183:assumed-role/AmazonSageMaker-ExecutionRole-20171129T072388/SageMaker"
Expand All @@ -143,6 +209,7 @@ def test_get_caller_identity_arn_from_a_execution_role(boto_session):
)


@patch("os.path.exists", side_effect=mock_exists(NOTEBOOK_METADATA_FILE, False))
def test_get_caller_identity_arn_from_role_with_path(boto_session):
sess = Session(boto_session)
arn_prefix = "arn:aws:iam::369233609183:role"
Expand Down

0 comments on commit de676a1

Please sign in to comment.