Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions src/sagemaker/amazon/amazon_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,8 @@ def record_set(
key_prefix = key_prefix.lstrip("/")
logger.debug("Uploading to bucket %s and key_prefix %s", bucket, key_prefix)
manifest_s3_file = upload_numpy_to_s3_shards(
self.instance_count, s3, bucket, key_prefix, train, labels, encrypt
self.instance_count, s3, bucket, key_prefix, train, labels, encrypt,
sagemaker_session=self.sagemaker_session
)
logger.debug("Created manifest file %s", manifest_s3_file)
return RecordSet(
Expand Down Expand Up @@ -455,7 +456,7 @@ def _build_shards(num_shards, array):


def upload_numpy_to_s3_shards(
num_shards, s3, bucket, key_prefix, array, labels=None, encrypt=False
num_shards, s3, bucket, key_prefix, array, labels=None, encrypt=False, sagemaker_session=None
):
"""Upload the training ``array`` and ``labels`` arrays to ``num_shards`` S3 objects.

Expand All @@ -470,6 +471,8 @@ def upload_numpy_to_s3_shards(
array:
labels:
encrypt:
sagemaker_session: Optional. SageMaker session used to look up the account
id that is passed as ``ExpectedBucketOwner`` when ``bucket`` is the
session's default bucket.
"""
shards = _build_shards(num_shards, array)
if labels is not None:
Expand All @@ -478,6 +481,12 @@ def upload_numpy_to_s3_shards(
if key_prefix[-1] != "/":
key_prefix = key_prefix + "/"
extra_put_kwargs = {"ServerSideEncryption": "AES256"} if encrypt else {}
# Spot check: enforce ownership only when uploading to the session's default
# bucket. Cross-account destinations are left untouched.
if sagemaker_session is not None:
expected_owner = sagemaker_session._get_account_id_if_default_bucket(bucket)
if expected_owner:
extra_put_kwargs["ExpectedBucketOwner"] = expected_owner
try:
for shard_index, shard in enumerate(shards):
with tempfile.TemporaryFile() as file:
Expand Down
18 changes: 15 additions & 3 deletions src/sagemaker/async_inference/async_inference_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,11 @@ def _get_result_from_s3_output_path(self, output_path):
"""Get inference result from the output Amazon S3 path"""
bucket, key = parse_s3_url(output_path)
try:
response = self.predictor_async.s3_client.get_object(Bucket=bucket, Key=key)
get_kwargs = {"Bucket": bucket, "Key": key}
expected_owner = self.predictor_async.sagemaker_session._get_account_id_if_default_bucket(bucket)
if expected_owner:
get_kwargs["ExpectedBucketOwner"] = expected_owner
response = self.predictor_async.s3_client.get_object(**get_kwargs)
return self.predictor_async.predictor._handle_response(response)
except ClientError as ex:
if ex.response["Error"]["Code"] == "NoSuchKey":
Expand All @@ -113,14 +117,22 @@ def _get_result_from_s3_output_failure_paths(self, output_path, failure_path):
"""Get inference result from the output & failure Amazon S3 path"""
bucket, key = parse_s3_url(output_path)
try:
response = self.predictor_async.s3_client.get_object(Bucket=bucket, Key=key)
get_kwargs = {"Bucket": bucket, "Key": key}
expected_owner = self.predictor_async.sagemaker_session._get_account_id_if_default_bucket(bucket)
if expected_owner:
get_kwargs["ExpectedBucketOwner"] = expected_owner
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: This pattern of passing kwargs is repeated throughout the PR.
Is this the best way to do this?

Can we just pass ExpectedBucketOwner = None to the s3 client if the _get_account_id_if_default_bucket returns None?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Previous way of calling the s3 client looked way cleaner IMO.

response = self.predictor_async.s3_client.get_object(**get_kwargs)
return self.predictor_async.predictor._handle_response(response)
except ClientError as e:
if e.response["Error"]["Code"] == "NoSuchKey":
try:
failure_bucket, failure_key = parse_s3_url(failure_path)
fail_kwargs = {"Bucket": failure_bucket, "Key": failure_key}
fail_owner = self.predictor_async.sagemaker_session._get_account_id_if_default_bucket(failure_bucket)
if fail_owner:
fail_kwargs["ExpectedBucketOwner"] = fail_owner
failure_response = self.predictor_async.s3_client.get_object(
Bucket=failure_bucket, Key=failure_key
**fail_kwargs
)
failure_response = self.predictor_async.predictor._handle_response(
failure_response
Expand Down
3 changes: 3 additions & 0 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1105,6 +1105,9 @@ def _stage_user_code_in_s3(self) -> UploadedCode:
kms_key=kms_key,
s3_resource=self.sagemaker_session.s3_resource,
settings=self.sagemaker_session.settings,
expected_bucket_owner=self.sagemaker_session._get_account_id_if_default_bucket(
code_bucket
),
)

def _assign_s3_prefix(self, key_prefix=""):
Expand Down
35 changes: 32 additions & 3 deletions src/sagemaker/experiments/_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,16 @@ def upload_artifact(self, file_path, extra_args=None):
artifact_s3_key = "{}/{}/{}".format(
self.artifact_prefix, self.trial_component_name, artifact_name
)

# Spot check: enforce ownership only when uploading to the session's default
# bucket. Cross-account destinations are left untouched.
expected_owner = self.sagemaker_session._get_account_id_if_default_bucket(
self.artifact_bucket
)
if expected_owner:
extra_args = dict(extra_args) if extra_args else {}
extra_args["ExpectedBucketOwner"] = expected_owner

self._s3_client.upload_file(
file_path,
self.artifact_bucket,
Expand Down Expand Up @@ -133,9 +143,21 @@ def upload_object_artifact(self, artifact_name, artifact_object, file_extension=
artifact_s3_key = "{}/{}/{}".format(
self.artifact_prefix, self.trial_component_name, artifact_name
)
self._s3_client.put_object(
Body=json.dumps(artifact_object), Bucket=self.artifact_bucket, Key=artifact_s3_key

# Spot check: enforce ownership only when uploading to the session's default
# bucket. Cross-account destinations are left untouched.
put_kwargs = {
"Body": json.dumps(artifact_object),
"Bucket": self.artifact_bucket,
"Key": artifact_s3_key,
}
expected_owner = self.sagemaker_session._get_account_id_if_default_bucket(
self.artifact_bucket
)
if expected_owner:
put_kwargs["ExpectedBucketOwner"] = expected_owner

self._s3_client.put_object(**put_kwargs)
etag = self._try_get_etag(artifact_s3_key)
return "s3://{}/{}".format(self.artifact_bucket, artifact_s3_key), etag

Expand All @@ -149,7 +171,14 @@ def _try_get_etag(self, key):
str: The S3 object ETag if it allows, otherwise return None.
"""
try:
response = self._s3_client.head_object(Bucket=self.artifact_bucket, Key=key)
head_kwargs = {"Bucket": self.artifact_bucket, "Key": key}
expected_owner = self.sagemaker_session._get_account_id_if_default_bucket(
self.artifact_bucket
)
if expected_owner:
head_kwargs["ExpectedBucketOwner"] = expected_owner

response = self._s3_client.head_object(**head_kwargs)
return response["ETag"]
except botocore.exceptions.ClientError as error:
# requires read permissions
Expand Down
11 changes: 11 additions & 0 deletions src/sagemaker/fw_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,7 @@ def tar_and_upload_dir(
kms_key=None,
s3_resource=None,
settings: Optional[SessionSettings] = None,
expected_bucket_owner: Optional[str] = None,
) -> UploadedCode:
"""Package source files and upload a compress tar file to S3.

Expand Down Expand Up @@ -431,6 +432,12 @@ def tar_and_upload_dir(
settings (sagemaker.session_settings.SessionSettings): Optional. The settings
of the SageMaker ``Session``, can be used to override the default encryption
behavior (default: None).
expected_bucket_owner (str): Optional. AWS account id passed as
``ExpectedBucketOwner`` on the upload. Callers should supply this when
``bucket`` is the session's default bucket (via
``Session._get_account_id_if_default_bucket``) to defend against
bucket-squatting on the predictable default name. Leave as ``None`` for
cross-account destination buckets.
Returns:
sagemaker.fw_utils.UploadedCode: An object with the S3 bucket and key (S3 prefix) and
script name.
Expand Down Expand Up @@ -472,6 +479,10 @@ def tar_and_upload_dir(
else:
extra_args = None

if expected_bucket_owner:
extra_args = dict(extra_args) if extra_args else {}
extra_args["ExpectedBucketOwner"] = expected_bucket_owner

if s3_resource is None:
s3_resource = session.resource("s3", region_name=session.region_name)
else:
Expand Down
36 changes: 34 additions & 2 deletions src/sagemaker/lambda_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,12 +123,17 @@ def create(self):
bucket, key_prefix = s3.determine_bucket_and_prefix(
bucket=self.s3_bucket, key_prefix=None, sagemaker_session=self.session
)
# Spot check: if the resolved bucket is the session's default bucket,
# enforce ownership on the upload so an attacker cannot squat on the
# predictable default name.
expected_owner = self.session._get_account_id_if_default_bucket(bucket)
key = _upload_to_s3(
s3_client=_get_s3_client(self.session),
function_name=self.function_name,
zipped_code_dir=self.zipped_code_dir,
s3_bucket=bucket,
s3_key_prefix=key_prefix,
expected_bucket_owner=expected_owner,
)
code = {"S3Bucket": bucket, "S3Key": key}

Expand Down Expand Up @@ -179,6 +184,13 @@ def update(self):
else:
function_name_for_s3 = self.function_name

# Spot check: enforce ownership only when the resolved bucket is
# the session's default bucket (defends against squatting on the
# predictable default name). Other buckets are left untouched.
expected_owner = self.session._get_account_id_if_default_bucket(
bucket
)

response = lambda_client.update_function_code(
FunctionName=(self.function_name or self.function_arn),
S3Bucket=bucket,
Expand All @@ -188,6 +200,7 @@ def update(self):
zipped_code_dir=self.zipped_code_dir,
s3_bucket=bucket,
s3_key_prefix=key_prefix,
expected_bucket_owner=expected_owner,
),
)
return response
Expand Down Expand Up @@ -276,13 +289,29 @@ def _get_lambda_client(session):
return lambda_client


def _upload_to_s3(s3_client, function_name, zipped_code_dir, s3_bucket, s3_key_prefix=None):
def _upload_to_s3(
s3_client,
function_name,
zipped_code_dir,
s3_bucket,
s3_key_prefix=None,
expected_bucket_owner=None,
):
"""Upload the zipped code to the S3 bucket provided in the Lambda instance.

Lambda instance must have a path to the zipped code folder and an S3 bucket to upload
the code. The key will be lambda/function_name/code and the S3 URI where the code is
uploaded is in this format: s3://bucket_name/lambda/function_name/code.

Args:
s3_client: boto3 S3 client used for the upload.
function_name (str): Lambda function name used to build the S3 key.
zipped_code_dir (str): Local path to the zipped Lambda code.
s3_bucket (str): Destination S3 bucket.
s3_key_prefix (str): Optional S3 key prefix.
expected_bucket_owner (str): Optional account id passed as ``ExpectedBucketOwner``
on the upload when the destination bucket should belong to that account.

Returns: the S3 key where the code is uploaded.
"""

Expand All @@ -292,7 +321,10 @@ def _upload_to_s3(s3_client, function_name, zipped_code_dir, s3_bucket, s3_key_p
function_name,
"code",
)
s3_client.upload_file(zipped_code_dir, s3_bucket, key)
extra_args = None
if expected_bucket_owner:
extra_args = {"ExpectedBucketOwner": expected_bucket_owner}
s3_client.upload_file(zipped_code_dir, s3_bucket, key, ExtraArgs=extra_args)
return key


Expand Down
3 changes: 3 additions & 0 deletions src/sagemaker/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -788,6 +788,9 @@ def _upload_code(self, key_prefix: str, repack: bool = False) -> None:
dependencies=self.dependencies,
kms_key=self.model_kms_key,
settings=self.sagemaker_session.settings,
expected_bucket_owner=self.sagemaker_session._get_account_id_if_default_bucket(
bucket
),
)

if repack and self.model_data is not None and self.entry_point is not None:
Expand Down
12 changes: 11 additions & 1 deletion src/sagemaker/multidatamodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,17 @@ def add_model(self, model_data_source, model_data_path=None):
dst_s3_uri = s3.s3_path_join(dst_prefix, model_data_path)
else:
dst_s3_uri = s3.s3_path_join(dst_prefix, os.path.basename(model_data_source))
self.s3_client.upload_file(model_data_source, destination_bucket, dst_s3_uri)
# Spot check: enforce ownership only when uploading to the session's default
# bucket. Cross-account destinations are left untouched.
extra_args = None
expected_owner = self.sagemaker_session._get_account_id_if_default_bucket(
destination_bucket
)
if expected_owner:
extra_args = {"ExpectedBucketOwner": expected_owner}
self.s3_client.upload_file(
model_data_source, destination_bucket, dst_s3_uri, ExtraArgs=extra_args
)
# return upload_path
return s3.s3_path_join("s3://", destination_bucket, dst_s3_uri)

Expand Down
37 changes: 31 additions & 6 deletions src/sagemaker/predictor_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,9 +176,18 @@ def _upload_data_to_s3(
)

data = self.serializer.serialize(data)
self.s3_client.put_object(
Body=data, Bucket=bucket, Key=key, ContentType=self.serializer.CONTENT_TYPE
)
# Spot check: enforce ownership only when uploading to the session's default
# bucket. Cross-account destinations are left untouched.
put_kwargs = {
"Body": data,
"Bucket": bucket,
"Key": key,
"ContentType": self.serializer.CONTENT_TYPE,
}
expected_owner = self.sagemaker_session._get_account_id_if_default_bucket(bucket)
if expected_owner:
put_kwargs["ExpectedBucketOwner"] = expected_owner
self.s3_client.put_object(**put_kwargs)
input_path = input_path or "s3://{}/{}".format(bucket, key)

return input_path
Expand Down Expand Up @@ -241,7 +250,13 @@ def _check_output_path(self, output_path, waiter_config):
output_path=output_path,
seconds=waiter_config.delay * waiter_config.max_attempts,
)
s3_object = self.s3_client.get_object(Bucket=bucket, Key=key)
# Spot check: enforce ownership only when reading from the session's default
# bucket. Cross-account reads are left untouched.
get_kwargs = {"Bucket": bucket, "Key": key}
expected_owner = self.sagemaker_session._get_account_id_if_default_bucket(bucket)
if expected_owner:
get_kwargs["ExpectedBucketOwner"] = expected_owner
s3_object = self.s3_client.get_object(**get_kwargs)
result = self.predictor._handle_response(response=s3_object)
return result

Expand Down Expand Up @@ -311,12 +326,22 @@ def check_failure_file():
time.sleep(1)

if output_file_found.is_set():
s3_object = self.s3_client.get_object(Bucket=output_bucket, Key=output_key)
get_kwargs = {"Bucket": output_bucket, "Key": output_key}
expected_owner = self.sagemaker_session._get_account_id_if_default_bucket(
output_bucket
)
if expected_owner:
get_kwargs["ExpectedBucketOwner"] = expected_owner
s3_object = self.s3_client.get_object(**get_kwargs)
result = self.predictor._handle_response(response=s3_object)
return result

if failure_file_found.is_set():
failure_object = self.s3_client.get_object(Bucket=failure_bucket, Key=failure_key)
fail_kwargs = {"Bucket": failure_bucket, "Key": failure_key}
fail_owner = self.sagemaker_session._get_account_id_if_default_bucket(failure_bucket)
if fail_owner:
fail_kwargs["ExpectedBucketOwner"] = fail_owner
failure_object = self.s3_client.get_object(**fail_kwargs)
failure_response = self.predictor._handle_response(response=failure_object)
raise AsyncInferenceModelError(message=failure_response)

Expand Down
9 changes: 8 additions & 1 deletion src/sagemaker/pytorch/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -744,7 +744,14 @@ def _create_recipe_copy(self, original_s3_uri):
# Copy the object with the new name
copy_source = {"Bucket": bucket, "Key": original_key}

s3_client.copy_object(CopySource=copy_source, Bucket=bucket, Key=new_key)
# Spot check: enforce ownership only when copying within the session's
# default bucket. Cross-account buckets are left untouched.
copy_kwargs = {"CopySource": copy_source, "Bucket": bucket, "Key": new_key}
expected_owner = self.sagemaker_session._get_account_id_if_default_bucket(bucket)
if expected_owner:
copy_kwargs["ExpectedBucketOwner"] = expected_owner

s3_client.copy_object(**copy_kwargs)

return f"s3://{bucket}/{new_key}"

Expand Down
Loading
Loading