Skip to content

Commit

Permalink
feature: handler for stopping transform job (#850)
Browse files Browse the repository at this point in the history
  • Loading branch information
imujjwal96 authored and laurenyu committed Sep 4, 2019
1 parent c701100 commit 379ceac
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 0 deletions.
21 changes: 21 additions & 0 deletions src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -1103,6 +1103,27 @@ def wait_for_transform_job(self, job, poll=5):
self._check_job_status(job, desc, "TransformJobStatus")
return desc

def stop_transform_job(self, name):
"""Stop the Amazon SageMaker hyperparameter tuning job with the specified name.
Args:
name (str): Name of the Amazon SageMaker batch transform job.
Raises:
ClientError: If an error occurs while trying to stop the batch transform job.
"""
try:
LOGGER.info("Stopping transform job: %s", name)
self.sagemaker_client.stop_transform_job(TransformJobName=name)
except ClientError as e:
error_code = e.response["Error"]["Code"]
# allow to pass if the job already stopped
if error_code == "ValidationException":
LOGGER.info("Transform job: %s is already stopped or not running.", name)
else:
LOGGER.error("Error occurred while attempting to stop transform job: %s.", name)
raise

def _check_job_status(self, job, desc, status_key_name):
"""Check to see if the job completed successfully and, if not, construct and
raise a exceptions.UnexpectedStatusException.
Expand Down
12 changes: 12 additions & 0 deletions src/sagemaker/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,14 @@ def wait(self):
self._ensure_last_transform_job()
self.latest_transform_job.wait()

def stop_transform_job(self, wait=True):
"""Stop latest running batch transform job.
"""
self._ensure_last_transform_job()
self.latest_transform_job.stop()
if wait:
self.latest_transform_job.wait()

def _ensure_last_transform_job(self):
"""Placeholder docstring"""
if self.latest_transform_job is None:
Expand Down Expand Up @@ -346,6 +354,10 @@ def start_new(
def wait(self):
self.sagemaker_session.wait_for_transform_job(self.job_name)

def stop(self):
"""Placeholder docstring"""
self.sagemaker_session.stop_transform_job(name=self.job_name)

@staticmethod
def _load_config(data, data_type, content_type, compression_type, split_type, transformer):
"""
Expand Down
49 changes: 49 additions & 0 deletions tests/integ/test_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import os
import pickle
import sys
import time

import pytest

Expand Down Expand Up @@ -349,6 +350,54 @@ def test_single_transformer_multiple_jobs(sagemaker_session, mxnet_full_version,
)


def test_stop_transform_job(sagemaker_session, mxnet_full_version):
data_path = os.path.join(DATA_DIR, "mxnet_mnist")
script_path = os.path.join(data_path, "mnist.py")
tags = [{"Key": "some-tag", "Value": "value-for-tag"}]

mx = MXNet(
entry_point=script_path,
role="SageMakerRole",
train_instance_count=1,
train_instance_type="ml.c4.xlarge",
sagemaker_session=sagemaker_session,
framework_version=mxnet_full_version,
)

train_input = mx.sagemaker_session.upload_data(
path=os.path.join(data_path, "train"), key_prefix="integ-test-data/mxnet_mnist/train"
)
test_input = mx.sagemaker_session.upload_data(
path=os.path.join(data_path, "test"), key_prefix="integ-test-data/mxnet_mnist/test"
)
job_name = unique_name_from_base("test-mxnet-transform")

with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
mx.fit({"train": train_input, "test": test_input}, job_name=job_name)

transform_input_path = os.path.join(data_path, "transform", "data.csv")
transform_input_key_prefix = "integ-test-data/mxnet_mnist/transform"
transform_input = mx.sagemaker_session.upload_data(
path=transform_input_path, key_prefix=transform_input_key_prefix
)

transformer = mx.transformer(1, "ml.m4.xlarge", tags=tags)
transformer.transform(transform_input, content_type="text/csv")

time.sleep(15)

latest_transform_job_name = transformer.latest_transform_job.name

print("Attempting to stop {}".format(latest_transform_job_name))

transformer.stop_transform_job()

desc = transformer.latest_transform_job.sagemaker_session.sagemaker_client.describe_transform_job(
TransformJobName=latest_transform_job_name
)
assert desc["TransformJobStatus"] == "Stopped"


def _create_transformer_and_transform_job(
estimator,
transform_input,
Expand Down
15 changes: 15 additions & 0 deletions tests/unit/test_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,3 +449,18 @@ def test_restart_output_path(start_new_job, transformer, sagemaker_session):

transformer.transform(DATA, job_name="job-2")
assert transformer.output_path == "s3://{}/{}".format(S3_BUCKET, "job-2")


def test_stop_transform_job(sagemaker_session, transformer):
sagemaker_session.stop_transform_job = Mock(name="stop_transform_job")
transformer.latest_transform_job = _TransformJob(sagemaker_session, JOB_NAME)

transformer.stop_transform_job()

sagemaker_session.stop_transform_job.assert_called_once_with(name=JOB_NAME)


def test_stop_transform_job_no_transform_job(transformer):
with pytest.raises(ValueError) as e:
transformer.stop_transform_job()
assert "No transform job available" in str(e)

0 comments on commit 379ceac

Please sign in to comment.