Skip to content

Commit

Permalink
fix: use default bucket for checkpoint_s3_uri integ test (#1085)
Browse files: browse the repository at this point in the history
  • Loading branch information
chuyang-deng authored and Dan committed Oct 12, 2019
1 parent 8a6b784 commit d51792d
Showing 1 changed file with 10 additions and 30 deletions.
40 changes: 10 additions & 30 deletions tests/integ/test_tf_script_mode.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -19,7 +19,7 @@
import pytest

from sagemaker.tensorflow import TensorFlow
from sagemaker.utils import unique_name_from_base
from sagemaker.utils import unique_name_from_base, sagemaker_timestamp

import tests.integ
from tests.integ import timeout
Expand All @@ -39,7 +39,11 @@
TAGS = [{"Key": "some-key", "Value": "some-value"}]


def test_mnist(sagemaker_session, instance_type):
def test_mnist_with_checkpoint_config(sagemaker_session, instance_type):
checkpoint_s3_uri = "s3://{}/checkpoints/tf-{}".format(
sagemaker_session.default_bucket(), sagemaker_timestamp()
)
checkpoint_local_path = "/test/checkpoint/path"
estimator = TensorFlow(
entry_point=SCRIPT,
role="SageMakerRole",
Expand All @@ -50,13 +54,16 @@ def test_mnist(sagemaker_session, instance_type):
framework_version=TensorFlow.LATEST_VERSION,
py_version=tests.integ.PYTHON_VERSION,
metric_definitions=[{"Name": "train:global_steps", "Regex": r"global_step\/sec:\s(.*)"}],
checkpoint_s3_uri=checkpoint_s3_uri,
checkpoint_local_path=checkpoint_local_path,
)
inputs = estimator.sagemaker_session.upload_data(
path=os.path.join(MNIST_RESOURCE_PATH, "data"), key_prefix="scriptmode/mnist"
)

training_job_name = unique_name_from_base("test-tf-sm-mnist")
with tests.integ.timeout.timeout(minutes=tests.integ.TRAINING_DEFAULT_TIMEOUT_MINUTES):
estimator.fit(inputs=inputs, job_name=unique_name_from_base("test-tf-sm-mnist"))
estimator.fit(inputs=inputs, job_name=training_job_name)
assert_s3_files_exist(
sagemaker_session,
estimator.model_dir,
Expand All @@ -65,33 +72,6 @@ def test_mnist(sagemaker_session, instance_type):
df = estimator.training_job_analytics.dataframe()
assert df.size > 0


@pytest.mark.skipif(
tests.integ.test_region() != "us-east-1",
reason="checkpoint s3 bucket is in us-east-1, ListObjectsV2 will fail in other regions",
)
def test_checkpoint_config(sagemaker_session, instance_type):
checkpoint_s3_uri = "s3://142577830533-us-east-1-sagemaker-checkpoint"
checkpoint_local_path = "/test/checkpoint/path"
estimator = TensorFlow(
entry_point=SCRIPT,
role="SageMakerRole",
train_instance_count=1,
train_instance_type=instance_type,
sagemaker_session=sagemaker_session,
script_mode=True,
framework_version=TensorFlow.LATEST_VERSION,
py_version=tests.integ.PYTHON_VERSION,
checkpoint_s3_uri=checkpoint_s3_uri,
checkpoint_local_path=checkpoint_local_path,
)
inputs = estimator.sagemaker_session.upload_data(
path=os.path.join(MNIST_RESOURCE_PATH, "data"), key_prefix="script/mnist"
)
training_job_name = unique_name_from_base("test-tf-sm-checkpoint")
with tests.integ.timeout.timeout(minutes=tests.integ.TRAINING_DEFAULT_TIMEOUT_MINUTES):
estimator.fit(inputs=inputs, job_name=training_job_name)

expected_training_checkpoint_config = {
"S3Uri": checkpoint_s3_uri,
"LocalPath": checkpoint_local_path,
Expand Down

0 comments on commit d51792d

Please sign in to comment.