Skip to content

Commit

Permalink
fix: take checkpoint_s3_uri and checkpoint_local_path in Framework cl…
Browse files Browse the repository at this point in the history
…ass (#1080)
  • Loading branch information
chuyang-deng authored and Dan committed Oct 10, 2019
1 parent 64834c2 commit 135171f
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 0 deletions.
15 changes: 15 additions & 0 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1227,6 +1227,8 @@ def __init__(
dependencies=None,
enable_network_isolation=False,
git_config=None,
checkpoint_s3_uri=None,
checkpoint_local_path=None,
**kwargs
):
"""Base class initializer. Subclasses which override ``__init__`` should
Expand Down Expand Up @@ -1363,6 +1365,17 @@ def __init__(
authentication if they are provided; otherwise, python SDK will
try to use either CodeCommit credential helper or local
credential storage for authentication.
checkpoint_s3_uri (str): The S3 URI to which SageMaker persists any
checkpoints the algorithm writes during training. (default:
``None``).
checkpoint_local_path (str): The local path that the algorithm
writes its checkpoints to. SageMaker will persist all files
under this path to `checkpoint_s3_uri` continually during
training. On job startup the reverse happens: data from the
S3 location is downloaded to this path before the algorithm is
started. If the path is unset, SageMaker assumes the
checkpoints will be provided under `/opt/ml/checkpoints/`.
(default: ``None``).
**kwargs: Additional kwargs passed to the ``EstimatorBase``
constructor.
"""
Expand Down Expand Up @@ -1391,6 +1404,8 @@ def __init__(
self.uploaded_code = None

self._hyperparameters = hyperparameters or {}
self.checkpoint_s3_uri = checkpoint_s3_uri
self.checkpoint_local_path = checkpoint_local_path

def enable_network_isolation(self):
"""Return True if this Estimator can use network isolation to run.
Expand Down
36 changes: 36 additions & 0 deletions tests/integ/test_tf_script_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,42 @@ def test_mnist(sagemaker_session, instance_type):
assert df.size > 0


@pytest.mark.skipif(
    tests.integ.test_region() != "us-east-1",
    reason="checkpoint s3 bucket is in us-east-1, ListObjectsV2 will fail in other regions",
)
def test_checkpoint_config(sagemaker_session, instance_type):
    """Check that checkpoint settings given to the estimator reach the training job.

    A script-mode TensorFlow estimator is built with ``checkpoint_s3_uri`` and
    ``checkpoint_local_path``, fitted on the uploaded MNIST data, and the
    ``CheckpointConfig`` of the described training job is compared with the
    values that were passed in.
    """
    s3_uri = "s3://142577830533-us-east-1-sagemaker-checkpoint"
    local_path = "/test/checkpoint/path"

    estimator_kwargs = dict(
        entry_point=SCRIPT,
        role="SageMakerRole",
        train_instance_count=1,
        train_instance_type=instance_type,
        sagemaker_session=sagemaker_session,
        script_mode=True,
        framework_version=TensorFlow.LATEST_VERSION,
        py_version=tests.integ.PYTHON_VERSION,
        checkpoint_s3_uri=s3_uri,
        checkpoint_local_path=local_path,
    )
    tf_estimator = TensorFlow(**estimator_kwargs)

    # Upload the training data through the estimator's own session.
    mnist_data_dir = os.path.join(MNIST_RESOURCE_PATH, "data")
    train_input = tf_estimator.sagemaker_session.upload_data(
        path=mnist_data_dir, key_prefix="script/mnist"
    )

    # An explicit job name lets the job be looked up again after fit() returns.
    job_name = unique_name_from_base("test-tf-sm-checkpoint")
    time_limit = tests.integ.TRAINING_DEFAULT_TIMEOUT_MINUTES
    with tests.integ.timeout.timeout(minutes=time_limit):
        tf_estimator.fit(inputs=train_input, job_name=job_name)

    job_description = sagemaker_session.sagemaker_client.describe_training_job(
        TrainingJobName=job_name
    )
    assert job_description["CheckpointConfig"] == {
        "S3Uri": s3_uri,
        "LocalPath": local_path,
    }


def test_server_side_encryption(sagemaker_session):
boto_session = sagemaker_session.boto_session
with kms_utils.bucket_with_encryption(boto_session, ROLE) as (bucket_with_kms, kms_key):
Expand Down
4 changes: 4 additions & 0 deletions tests/unit/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,8 @@ def test_framework_all_init_args(sagemaker_session):
security_group_ids=["789", "012"],
metric_definitions=[{"Name": "validation-rmse", "Regex": "validation-rmse=(\\d+)"}],
encrypt_inter_container_traffic=True,
checkpoint_s3_uri="s3://bucket/checkpoint",
checkpoint_local_path="file://local/checkpoint",
)
_TrainingJob.start_new(f, "s3://mydata")
sagemaker_session.train.assert_called_once()
Expand Down Expand Up @@ -237,6 +239,8 @@ def test_framework_all_init_args(sagemaker_session):
},
"metric_definitions": [{"Name": "validation-rmse", "Regex": "validation-rmse=(\\d+)"}],
"encrypt_inter_container_traffic": True,
"checkpoint_s3_uri": "s3://bucket/checkpoint",
"checkpoint_local_path": "file://local/checkpoint",
}


Expand Down

0 comments on commit 135171f

Please sign in to comment.