Add support for Managed Spot Training and Checkpoint support (#990)
ishaaq authored and knakad committed Aug 20, 2019
1 parent 3cf4f9b commit d0c6764
Showing 6 changed files with 207 additions and 7 deletions.
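For context, here is a minimal usage sketch of the arguments this commit introduces. The image URI, role ARN, and S3 paths are illustrative placeholders, not values taken from this change:

from sagemaker.estimator import Estimator

# Hypothetical configuration exercising the new spot/checkpoint arguments.
estimator = Estimator(
    image_name="123456789012.dkr.ecr.us-west-2.amazonaws.com/my-image:latest",
    role="arn:aws:iam::123456789012:role/MySageMakerRole",
    train_instance_count=1,
    train_instance_type="ml.m4.xlarge",
    train_max_run=3600,             # becomes MaxRuntimeInSeconds
    train_use_spot_instances=True,  # becomes EnableManagedSpotTraining
    train_max_wait=7200,            # becomes MaxWaitTimeInSeconds
    checkpoint_s3_uri="s3://my-bucket/checkpoints/",
    checkpoint_local_path="/opt/ml/checkpoints/",
)
estimator.fit("s3://my-bucket/training-data/")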
85 changes: 85 additions & 0 deletions src/sagemaker/estimator.py
@@ -82,6 +82,10 @@ def __init__(
model_channel_name="model",
metric_definitions=None,
encrypt_inter_container_traffic=False,
train_use_spot_instances=False,
train_max_wait=None,
checkpoint_s3_uri=None,
checkpoint_local_path=None,
):
"""Initialize an ``EstimatorBase`` instance.
@@ -157,6 +161,28 @@ def __init__(
encrypt_inter_container_traffic (bool): Specifies whether traffic
between training containers is encrypted for the training job
(default: ``False``).
train_use_spot_instances (bool): Specifies whether to use SageMaker
Managed Spot instances for training. If enabled, then
``train_max_wait`` should also be set.
More information:
https://docs.aws.amazon.com/sagemaker/latest/dg/model-managed-spot-training.html
(default: ``False``).
train_max_wait (int): Timeout in seconds waiting for Spot training
instances. After this amount of time Amazon SageMaker stops
waiting for Spot instances to become available
(default: ``None``).
checkpoint_s3_uri (str): The S3 URI to which to persist checkpoints
that the algorithm generates (if any) during training
(default: ``None``).
checkpoint_local_path (str): The local path that the algorithm
writes its checkpoints to. SageMaker persists all files
under this path to ``checkpoint_s3_uri`` continually during
training. On job startup, the reverse happens: data from the
S3 location is downloaded to this path before the algorithm
starts. If the path is unset, SageMaker assumes the
checkpoints will be written under ``/opt/ml/checkpoints/``
(default: ``None``).
"""
self.role = role
self.train_instance_count = train_instance_count
@@ -199,6 +225,10 @@ def __init__(
self.security_group_ids = security_group_ids

self.encrypt_inter_container_traffic = encrypt_inter_container_traffic
self.train_use_spot_instances = train_use_spot_instances
self.train_max_wait = train_max_wait
self.checkpoint_s3_uri = checkpoint_s3_uri
self.checkpoint_local_path = checkpoint_local_path

@abstractmethod
def train_image(self):
@@ -795,10 +825,35 @@ def start_new(cls, estimator, inputs):
else:
train_args["image"] = estimator.train_image()

cls._add_spot_checkpoint_args(local_mode, estimator, train_args)

estimator.sagemaker_session.train(**train_args)

return cls(estimator.sagemaker_session, estimator._current_job_name)

@classmethod
def _add_spot_checkpoint_args(cls, local_mode, estimator, train_args):
"""
Args:
local_mode:
estimator:
train_args:
"""
if estimator.train_use_spot_instances:
if local_mode:
raise ValueError("Spot training is not supported in local mode.")
train_args["train_use_spot_instances"] = True

if estimator.checkpoint_s3_uri:
if local_mode:
raise ValueError("Setting checkpoint_s3_uri is not supported in local mode.")
train_args["checkpoint_s3_uri"] = estimator.checkpoint_s3_uri

if estimator.checkpoint_local_path:
if local_mode:
raise ValueError("Setting checkpoint_local_path is not supported in local mode.")
train_args["checkpoint_local_path"] = estimator.checkpoint_local_path

@classmethod
def _is_local_channel(cls, input_uri):
"""
@@ -845,6 +900,10 @@ def __init__(
model_channel_name="model",
metric_definitions=None,
encrypt_inter_container_traffic=False,
train_use_spot_instances=False,
train_max_wait=None,
checkpoint_s3_uri=None,
checkpoint_local_path=None,
):
"""Initialize an ``Estimator`` instance.
@@ -926,6 +985,28 @@ def __init__(
encrypt_inter_container_traffic (bool): Specifies whether traffic
between training containers is encrypted for the training job
(default: ``False``).
train_use_spot_instances (bool): Specifies whether to use SageMaker
Managed Spot instances for training. If enabled, then
``train_max_wait`` should also be set.
More information:
https://docs.aws.amazon.com/sagemaker/latest/dg/model-managed-spot-training.html
(default: ``False``).
train_max_wait (int): Timeout in seconds waiting for Spot training
instances. After this amount of time Amazon SageMaker stops
waiting for Spot instances to become available
(default: ``None``).
checkpoint_s3_uri (str): The S3 URI to which to persist checkpoints
that the algorithm generates (if any) during training
(default: ``None``).
checkpoint_local_path (str): The local path that the algorithm
writes its checkpoints to. SageMaker persists all files
under this path to ``checkpoint_s3_uri`` continually during
training. On job startup, the reverse happens: data from the
S3 location is downloaded to this path before the algorithm
starts. If the path is unset, SageMaker assumes the
checkpoints will be written under ``/opt/ml/checkpoints/``
(default: ``None``).
"""
self.image_name = image_name
self.hyperparam_dict = hyperparameters.copy() if hyperparameters else {}
@@ -948,6 +1029,10 @@ def __init__(
model_channel_name=model_channel_name,
metric_definitions=metric_definitions,
encrypt_inter_container_traffic=encrypt_inter_container_traffic,
train_use_spot_instances=train_use_spot_instances,
train_max_wait=train_max_wait,
checkpoint_s3_uri=checkpoint_s3_uri,
checkpoint_local_path=checkpoint_local_path,
)

def train_image(self):
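Because `_add_spot_checkpoint_args` rejects spot and checkpoint settings under local mode, enabling them against a local session fails fast. A sketch of the expected behavior, assuming a local-mode estimator (the setup below is illustrative, not part of this diff):

# Assumed local-mode setup; train_instance_type="local" makes the SDK
# use a LocalSession, so _TrainingJob.start_new() hits the guard above.
local_estimator = Estimator(
    image_name="my-image:latest",
    role="arn:aws:iam::123456789012:role/MySageMakerRole",
    train_instance_count=1,
    train_instance_type="local",
    train_use_spot_instances=True,
)
local_estimator.fit("file:///tmp/training-data")
# Expected: ValueError: Spot training is not supported in local mode.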
9 changes: 7 additions & 2 deletions src/sagemaker/job.py
@@ -80,7 +80,9 @@ def _load_config(inputs, estimator, expand_role=True, validate_uri=True):
estimator.train_volume_size,
estimator.train_volume_kms_key,
)
stop_condition = _Job._prepare_stop_condition(estimator.train_max_run)
stop_condition = _Job._prepare_stop_condition(
estimator.train_max_run, estimator.train_max_wait
)
vpc_config = estimator.get_vpc_config()

model_channel = _Job._prepare_channel(
@@ -312,11 +314,14 @@ def _prepare_resource_config(instance_count, instance_type, volume_size, train_v
return resource_config

@staticmethod
def _prepare_stop_condition(max_run):
def _prepare_stop_condition(max_run, max_wait):
"""
Args:
max_run:
max_wait:
"""
if max_wait:
return {"MaxRuntimeInSeconds": max_run, "MaxWaitTimeInSeconds": max_wait}
return {"MaxRuntimeInSeconds": max_run}

@property
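The stop-condition change above is self-contained; the following standalone mirror of the logic shows the request dictionaries it produces (the function is a copy for illustration, not the SDK's own):

# Standalone mirror of _Job._prepare_stop_condition for illustration.
def prepare_stop_condition(max_run, max_wait=None):
    if max_wait:
        return {"MaxRuntimeInSeconds": max_run, "MaxWaitTimeInSeconds": max_wait}
    return {"MaxRuntimeInSeconds": max_run}

print(prepare_stop_condition(3600, 7200))
# {'MaxRuntimeInSeconds': 3600, 'MaxWaitTimeInSeconds': 7200}
print(prepare_stop_condition(3600))
# {'MaxRuntimeInSeconds': 3600}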
37 changes: 33 additions & 4 deletions src/sagemaker/session.py
@@ -257,6 +257,9 @@ def train( # noqa: C901
image=None,
algorithm_arn=None,
encrypt_inter_container_traffic=False,
train_use_spot_instances=False,
checkpoint_s3_uri=None,
checkpoint_local_path=None,
):
"""Create an Amazon SageMaker training job.
@@ -307,6 +310,18 @@
algorithm_arn (str): Algorithm Arn from Marketplace.
encrypt_inter_container_traffic (bool): Specifies whether traffic between training
containers is encrypted for the training job (default: ``False``).
train_use_spot_instances (bool): Specifies whether to use SageMaker
Managed Spot instances for training (default: ``False``).
checkpoint_s3_uri (str): The S3 URI to which to persist checkpoints
that the algorithm generates (if any) during training
(default: ``None``).
checkpoint_local_path (str): The local path that the algorithm
writes its checkpoints to. SageMaker persists all files
under this path to ``checkpoint_s3_uri`` continually during
training. On job startup, the reverse happens: data from the
S3 location is downloaded to this path before the algorithm
starts. If the path is unset, SageMaker assumes the
checkpoints will be written under ``/opt/ml/checkpoints/``
(default: ``None``).
Returns:
str: ARN of the training job, if it is created.
@@ -357,6 +372,15 @@ def train( # noqa: C901
if encrypt_inter_container_traffic:
train_request["EnableInterContainerTrafficEncryption"] = encrypt_inter_container_traffic

if train_use_spot_instances:
train_request["EnableManagedSpotTraining"] = train_use_spot_instances

if checkpoint_s3_uri:
checkpoint_config = {"S3Uri": checkpoint_s3_uri}
if checkpoint_local_path:
checkpoint_config["LocalPath"] = checkpoint_local_path
train_request["CheckpointConfig"] = checkpoint_config

LOGGER.info("Creating training-job with name: %s", job_name)
LOGGER.debug("train request: %s", json.dumps(train_request, indent=4))
self.sagemaker_client.create_training_job(**train_request)
@@ -1468,10 +1492,15 @@ def logs_for_job( # noqa: C901 - suppress complexity warning for this method
print()
# Customers are not billed for hardware provisioning, so billable time is less than
# total time
billable_time = (
description["TrainingEndTime"] - description["TrainingStartTime"]
) * instance_count
print("Billable seconds:", int(billable_time.total_seconds()) + 1)
training_time = description.get("TrainingTimeInSeconds")
billable_time = description.get("BillableTimeInSeconds")
if training_time is not None:
print("Training seconds:", training_time * instance_count)
if billable_time is not None:
print("Billable seconds:", billable_time * instance_count)
if description.get("EnableManagedSpotTraining"):
saving = (1 - float(billable_time) / training_time) * 100
print("Managed Spot Training savings: {:.1f}%".format(saving))


def container_def(image, model_data_url=None, env=None):
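The savings figure printed by `logs_for_job` comes straight from the two DescribeTrainingJob fields; a worked example with made-up numbers:

# Illustrative values, not from a real job description.
training_time = 1000  # TrainingTimeInSeconds
billable_time = 300   # BillableTimeInSeconds
saving = (1 - float(billable_time) / training_time) * 100
print("Managed Spot Training savings: {:.1f}%".format(saving))
# Managed Spot Training savings: 70.0%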
63 changes: 63 additions & 0 deletions tests/unit/test_estimator.py
@@ -227,6 +227,69 @@ def test_framework_all_init_args(sagemaker_session):
}


def test_framework_with_spot_and_checkpoints(sagemaker_session):
f = DummyFramework(
"my_script.py",
role="DummyRole",
train_instance_count=3,
train_instance_type="ml.m4.xlarge",
sagemaker_session=sagemaker_session,
train_volume_size=123,
train_volume_kms_key="volumekms",
train_max_run=456,
input_mode="inputmode",
output_path="outputpath",
output_kms_key="outputkms",
base_job_name="basejobname",
tags=[{"foo": "bar"}],
subnets=["123", "456"],
security_group_ids=["789", "012"],
metric_definitions=[{"Name": "validation-rmse", "Regex": "validation-rmse=(\\d+)"}],
encrypt_inter_container_traffic=True,
train_use_spot_instances=True,
train_max_wait=500,
checkpoint_s3_uri="s3://mybucket/checkpoints/",
checkpoint_local_path="/tmp/checkpoints",
)
_TrainingJob.start_new(f, "s3://mydata")
sagemaker_session.train.assert_called_once()
_, args = sagemaker_session.train.call_args
assert args == {
"input_mode": "inputmode",
"tags": [{"foo": "bar"}],
"hyperparameters": {},
"image": "fakeimage",
"input_config": [
{
"ChannelName": "training",
"DataSource": {
"S3DataSource": {
"S3DataType": "S3Prefix",
"S3DataDistributionType": "FullyReplicated",
"S3Uri": "s3://mydata",
}
},
}
],
"output_config": {"KmsKeyId": "outputkms", "S3OutputPath": "outputpath"},
"vpc_config": {"Subnets": ["123", "456"], "SecurityGroupIds": ["789", "012"]},
"stop_condition": {"MaxRuntimeInSeconds": 456, "MaxWaitTimeInSeconds": 500},
"role": sagemaker_session.expand_role(),
"job_name": None,
"resource_config": {
"VolumeSizeInGB": 123,
"InstanceCount": 3,
"VolumeKmsKeyId": "volumekms",
"InstanceType": "ml.m4.xlarge",
},
"metric_definitions": [{"Name": "validation-rmse", "Regex": "validation-rmse=(\\d+)"}],
"encrypt_inter_container_traffic": True,
"train_use_spot_instances": True,
"checkpoint_s3_uri": "s3://mybucket/checkpoints/",
"checkpoint_local_path": "/tmp/checkpoints",
}


def test_framework_init_s3_entry_point_invalid(sagemaker_session):
with pytest.raises(ValueError) as error:
DummyFramework(
14 changes: 13 additions & 1 deletion tests/unit/test_job.py
@@ -563,10 +563,22 @@ def test_prepare_resource_config_with_volume_kms():

def test_prepare_stop_condition():
max_run = 1
max_wait = 2

stop_condition = _Job._prepare_stop_condition(max_run)
stop_condition = _Job._prepare_stop_condition(max_run, max_wait)

assert stop_condition["MaxRuntimeInSeconds"] == max_run
assert stop_condition["MaxWaitTimeInSeconds"] == max_wait


def test_prepare_stop_condition_no_wait():
max_run = 1
max_wait = None

stop_condition = _Job._prepare_stop_condition(max_run, max_wait)

assert stop_condition["MaxRuntimeInSeconds"] == max_run
assert "MaxWaitTimeInSeconds" not in stop_condition


def test_name(sagemaker_session):
6 changes: 6 additions & 0 deletions tests/unit/test_session.py
@@ -651,6 +651,9 @@ def test_train_pack_to_request_with_optional_params(sagemaker_session):
tags=TAGS,
metric_definitions=METRIC_DEFINITONS,
encrypt_inter_container_traffic=True,
train_use_spot_instances=True,
checkpoint_s3_uri="s3://mybucket/checkpoints/",
checkpoint_local_path="/tmp/checkpoints",
)

_, _, actual_train_args = sagemaker_session.sagemaker_client.method_calls[0]
@@ -660,6 +663,9 @@ def test_train_pack_to_request_with_optional_params(sagemaker_session):
assert actual_train_args["Tags"] == TAGS
assert actual_train_args["AlgorithmSpecification"]["MetricDefinitions"] == METRIC_DEFINITONS
assert actual_train_args["EnableInterContainerTrafficEncryption"] is True
assert actual_train_args["EnableManagedSpotTraining"] is True
assert actual_train_args["CheckpointConfig"]["S3Uri"] == "s3://mybucket/checkpoints/"
assert actual_train_args["CheckpointConfig"]["LocalPath"] == "/tmp/checkpoints"


def test_transform_pack_to_request(sagemaker_session):
