Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions src/sagemaker/local/local_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,9 @@ class LocalSession(Session):
:class:`~sagemaker.session.Session`.
"""

def __init__(self, boto_session=None, s3_endpoint_url=None, disable_local_code=False):
def __init__(
self, boto_session=None, s3_endpoint_url=None, disable_local_code=False, default_bucket=None
):
"""Create a Local SageMaker Session.

Args:
Expand All @@ -503,14 +505,20 @@ def __init__(self, boto_session=None, s3_endpoint_url=None, disable_local_code=F
disable_local_code (bool): Set ``True`` to override the default AWS configuration
chain to disable the ``local.local_code`` setting, which may not be supported for
some SDK features (default: False).
default_bucket (str): The default Amazon S3 bucket to be used by this session.
This will be created the next time an Amazon S3 bucket is needed (by calling
:func:`default_bucket`).
If not provided, a default bucket will be created based on the following format:
"sagemaker-{region}-{aws-account-id}".
Example: "sagemaker-my-custom-bucket".
"""
self.s3_endpoint_url = s3_endpoint_url
# We use this local variable to avoid disrupting the __init__->_initialize API of the
# parent class... But overwriting it after constructor won't do anything, so prefix _ to
# discourage external use:
self._disable_local_code = disable_local_code

super(LocalSession, self).__init__(boto_session)
super(LocalSession, self).__init__(boto_session, default_bucket=default_bucket)

if platform.system() == "Windows":
logger.warning("Windows Support for Local Mode is Experimental")
Expand Down