feature: Estimator.fit-like logs for transformer (#782)
imujjwal96 authored and laurenyu committed Sep 6, 2019
1 parent f54f506 commit c364fd1
Showing 4 changed files with 334 additions and 66 deletions.
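Before the diff, a minimal usage sketch of what this change enables. The model name, S3 paths, and instance type below are hypothetical placeholders; `wait` and `logs` are the new `transform()` parameters introduced in this commit.

    from sagemaker.transformer import Transformer

    # Placeholder model name, S3 locations, and instance type.
    transformer = Transformer(
        model_name="my-existing-model",
        instance_count=1,
        instance_type="ml.m5.xlarge",
        output_path="s3://my-bucket/transform-output",
    )

    # With wait=True and logs=True, transform() now blocks until the batch
    # transform job finishes and tails its CloudWatch logs, mirroring the
    # behavior of Estimator.fit().
    transformer.transform(
        data="s3://my-bucket/transform-input/data.csv",
        content_type="text/csv",
        wait=True,
        logs=True,
    )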
236 changes: 174 additions & 62 deletions src/sagemaker/session.py
@@ -1428,24 +1428,12 @@ def logs_for_job( # noqa: C901 - suppress complexity warning for this method

description = self.sagemaker_client.describe_training_job(TrainingJobName=job_name)
print(secondary_training_status_message(description, None), end="")
instance_count = description["ResourceConfig"]["InstanceCount"]
status = description["TrainingJobStatus"]

stream_names = [] # The list of log streams
positions = {} # The current position in each stream, map of stream name -> position

# Increase retries allowed (from default of 4), as we don't want waiting for a training job
# to be interrupted by a transient exception.
config = botocore.config.Config(retries={"max_attempts": 15})
client = self.boto_session.client("logs", config=config)
log_group = "/aws/sagemaker/TrainingJobs"

job_already_completed = status in ("Completed", "Failed", "Stopped")

state = LogState.TAILING if wait and not job_already_completed else LogState.COMPLETE
dot = False
instance_count, stream_names, positions, client, log_group, dot, color_wrap = _logs_init(
self, description, job="Training"
)

color_wrap = sagemaker.logs.ColorWrap()
state = _get_initial_job_state(description, "TrainingJobStatus", wait)

# The loop below implements a state machine that alternates between checking the job status
# and reading whatever is available in the logs at this point. Note, that if we were
@@ -1470,52 +1458,16 @@ def logs_for_job( # noqa: C901 - suppress complexity warning for this method
last_describe_job_call = time.time()
last_description = description
while True:
if len(stream_names) < instance_count:
# Log streams are created whenever a container starts writing to stdout/err, so
# this list # may be dynamic until we have a stream for every instance.
try:
streams = client.describe_log_streams(
logGroupName=log_group,
logStreamNamePrefix=job_name + "/",
orderBy="LogStreamName",
limit=instance_count,
)
stream_names = [s["logStreamName"] for s in streams["logStreams"]]
positions.update(
[
(s, sagemaker.logs.Position(timestamp=0, skip=0))
for s in stream_names
if s not in positions
]
)
except ClientError as e:
# On the very first training job run on an account, there's no log group until
# the container starts logging, so ignore any errors thrown about that
err = e.response.get("Error", {})
if err.get("Code", None) != "ResourceNotFoundException":
raise

if len(stream_names) > 0:
if dot:
print("")
dot = False
for idx, event in sagemaker.logs.multi_stream_iter(
client, log_group, stream_names, positions
):
color_wrap(idx, event["message"])
ts, count = positions[stream_names[idx]]
if event["timestamp"] == ts:
positions[stream_names[idx]] = sagemaker.logs.Position(
timestamp=ts, skip=count + 1
)
else:
positions[stream_names[idx]] = sagemaker.logs.Position(
timestamp=event["timestamp"], skip=1
)
else:
dot = True
print(".", end="")
sys.stdout.flush()
_flush_log_streams(
stream_names,
instance_count,
client,
log_group,
job_name,
positions,
dot,
color_wrap,
)
if state == LogState.COMPLETE:
break

@@ -1554,6 +1506,86 @@ def logs_for_job( # noqa: C901 - suppress complexity warning for this method
saving = (1 - float(billable_time) / training_time) * 100
print("Managed Spot Training savings: {:.1f}%".format(saving))

def logs_for_transform_job(self, job_name, wait=False, poll=10):
"""Display the logs for a given transform job, optionally tailing them until the
job is complete. If the output is a tty or a Jupyter cell, it will be color-coded
based on which instance the log entry is from.
Args:
job_name (str): Name of the transform job to display the logs for.
wait (bool): Whether to keep looking for new log entries until the job completes
(default: False).
poll (int): The interval in seconds between polling for new log entries and job
completion (default: 10).
Raises:
ValueError: If the transform job fails.
"""

description = self.sagemaker_client.describe_transform_job(TransformJobName=job_name)

instance_count, stream_names, positions, client, log_group, dot, color_wrap = _logs_init(
self, description, job="Transform"
)

state = _get_initial_job_state(description, "TransformJobStatus", wait)

# The loop below implements a state machine that alternates between checking the job status
# and reading whatever is available in the logs at this point. Note, that if we were
# called with wait == False, we never check the job status.
#
# If wait == TRUE and job is not completed, the initial state is TAILING
# If wait == FALSE, the initial state is COMPLETE (doesn't matter if the job really is
# complete).
#
# The state table:
#
# STATE ACTIONS CONDITION NEW STATE
# ---------------- ---------------- ----------------- ----------------
# TAILING Read logs, Pause, Get status Job complete JOB_COMPLETE
# Else TAILING
# JOB_COMPLETE Read logs, Pause Any COMPLETE
# COMPLETE Read logs, Exit N/A
#
# Notes:
# - The JOB_COMPLETE state forces us to do an extra pause and read any items that got to
# Cloudwatch after the job was marked complete.
last_describe_job_call = time.time()
while True:
_flush_log_streams(
stream_names,
instance_count,
client,
log_group,
job_name,
positions,
dot,
color_wrap,
)
if state == LogState.COMPLETE:
break

time.sleep(poll)

if state == LogState.JOB_COMPLETE:
state = LogState.COMPLETE
elif time.time() - last_describe_job_call >= 30:
description = self.sagemaker_client.describe_transform_job(
TransformJobName=job_name
)
last_describe_job_call = time.time()

status = description["TransformJobStatus"]

if status in ("Completed", "Failed", "Stopped"):
print()
state = LogState.JOB_COMPLETE

if wait:
self._check_job_status(job_name, description, "TransformJobStatus")
if dot:
print()


def container_def(image, model_data_url=None, env=None):
"""Create a definition for executing a container as part of a SageMaker model.
@@ -1892,3 +1924,83 @@ def _vpc_config_from_training_job(
if vpc_config_override is vpc_utils.VPC_CONFIG_DEFAULT:
return training_job_desc.get(vpc_utils.VPC_CONFIG_KEY)
return vpc_utils.sanitize(vpc_config_override)


def _get_initial_job_state(description, status_key, wait):
"""Placeholder docstring"""
status = description[status_key]
job_already_completed = status in ("Completed", "Failed", "Stopped")
return LogState.TAILING if wait and not job_already_completed else LogState.COMPLETE


def _logs_init(sagemaker_session, description, job):
"""Placeholder docstring"""
if job == "Training":
instance_count = description["ResourceConfig"]["InstanceCount"]
elif job == "Transform":
instance_count = description["TransformResources"]["InstanceCount"]

stream_names = [] # The list of log streams
positions = {} # The current position in each stream, map of stream name -> position

# Increase retries allowed (from default of 4), as we don't want waiting for a training or
# transform job to be interrupted by a transient exception.
config = botocore.config.Config(retries={"max_attempts": 15})
client = sagemaker_session.boto_session.client("logs", config=config)
log_group = "/aws/sagemaker/" + job + "Jobs"

dot = False

color_wrap = sagemaker.logs.ColorWrap()

return instance_count, stream_names, positions, client, log_group, dot, color_wrap


def _flush_log_streams(
stream_names, instance_count, client, log_group, job_name, positions, dot, color_wrap
):
"""Placeholder docstring"""
if len(stream_names) < instance_count:
# Log streams are created whenever a container starts writing to stdout/err, so this list
# may be dynamic until we have a stream for every instance.
try:
streams = client.describe_log_streams(
logGroupName=log_group,
logStreamNamePrefix=job_name + "/",
orderBy="LogStreamName",
limit=instance_count,
)
stream_names = [s["logStreamName"] for s in streams["logStreams"]]
positions.update(
[
(s, sagemaker.logs.Position(timestamp=0, skip=0))
for s in stream_names
if s not in positions
]
)
except ClientError as e:
# On the very first job run on an account, there's no log group until the
# container starts logging, so ignore any errors about a missing log group
err = e.response.get("Error", {})
if err.get("Code", None) != "ResourceNotFoundException":
raise

if len(stream_names) > 0:
if dot:
print("")
dot = False
for idx, event in sagemaker.logs.multi_stream_iter(
client, log_group, stream_names, positions
):
color_wrap(idx, event["message"])
ts, count = positions[stream_names[idx]]
if event["timestamp"] == ts:
positions[stream_names[idx]] = sagemaker.logs.Position(timestamp=ts, skip=count + 1)
else:
positions[stream_names[idx]] = sagemaker.logs.Position(
timestamp=event["timestamp"], skip=1
)
else:
dot = True
print(".", end="")
sys.stdout.flush()
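For context on the cursor handling that _flush_log_streams carries over from the old inline code: each stream's read position is a (timestamp, skip) pair, where skip counts how many events at that exact timestamp have already been printed, so repeated reads never duplicate output. A toy illustration, assuming sagemaker.logs.Position behaves like the plain namedtuple used here:

    from collections import namedtuple

    # Stand-in with the same shape as sagemaker.logs.Position (assumption).
    Position = namedtuple("Position", ["timestamp", "skip"])

    def advance(position, event):
        """Return the updated cursor after printing one CloudWatch event."""
        if event["timestamp"] == position.timestamp:
            # Another event at the same millisecond: bump the skip count.
            return Position(timestamp=position.timestamp, skip=position.skip + 1)
        # A newer timestamp: start counting from 1 again.
        return Position(timestamp=event["timestamp"], skip=1)

    pos = Position(timestamp=0, skip=0)
    for event in ({"timestamp": 100, "message": "a"},
                  {"timestamp": 100, "message": "b"},
                  {"timestamp": 200, "message": "c"}):
        pos = advance(pos, event)
    print(pos)  # Position(timestamp=200, skip=1)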
20 changes: 16 additions & 4 deletions src/sagemaker/transformer.py
@@ -119,6 +119,8 @@ def transform(
input_filter=None,
output_filter=None,
join_source=None,
wait=False,
logs=False,
):
"""Start a new transform job.
@@ -154,6 +156,10 @@ def transform(
will be joined to the inference result. You can use OutputFilter
to select the useful portion before uploading to S3. (default:
None). Valid values: Input, None.
wait (bool): Whether the call should wait until the job completes
(default: False).
logs (bool): Whether to show the logs produced by the job.
Only meaningful when wait is True (default: False).
"""
local_mode = self.sagemaker_session.local_mode
if not local_mode and not data.startswith("s3://"):
@@ -187,6 +193,9 @@
join_source,
)

if wait:
self.latest_transform_job.wait(logs=logs)

def delete_model(self):
"""Delete the corresponding SageMaker model for this Transformer."""
self.sagemaker_session.delete_model(self.model_name)
@@ -224,10 +233,10 @@ def _retrieve_image_name(self):
"Local instance types require locally created models." % self.model_name
)

def wait(self):
def wait(self, logs=True):
"""Placeholder docstring"""
self._ensure_last_transform_job()
self.latest_transform_job.wait()
self.latest_transform_job.wait(logs=logs)

def stop_transform_job(self, wait=True):
"""Stop latest running batch transform job.
@@ -351,8 +360,11 @@ def start_new(

return cls(transformer.sagemaker_session, transformer._current_job_name)

def wait(self):
self.sagemaker_session.wait_for_transform_job(self.job_name)
def wait(self, logs=True):
if logs:
self.sagemaker_session.logs_for_transform_job(self.job_name, wait=True)
else:
self.sagemaker_session.wait_for_transform_job(self.job_name)

def stop(self):
"""Placeholder docstring"""
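The new Transformer.wait(logs=True) path above ultimately lands in Session.logs_for_transform_job, which can also be called directly against an existing job. A small sketch, with the job name as a hypothetical placeholder:

    import sagemaker

    session = sagemaker.Session()
    job_name = "my-transform-job-2019-09-06"  # placeholder for an existing job

    # Tail the job's CloudWatch logs until it reaches a terminal state
    # (the Session method added in this commit).
    session.logs_for_transform_job(job_name, wait=True, poll=10)

    # Equivalent without log output: block on the existing waiter instead.
    session.wait_for_transform_job(job_name, poll=10)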
45 changes: 45 additions & 0 deletions tests/integ/test_transformer.py
@@ -398,6 +398,47 @@ def test_stop_transform_job(sagemaker_session, mxnet_full_version, cpu_instance_
assert desc["TransformJobStatus"] == "Stopped"


def test_transform_mxnet_logs(sagemaker_session, mxnet_full_version, cpu_instance_type):
data_path = os.path.join(DATA_DIR, "mxnet_mnist")
script_path = os.path.join(data_path, "mnist.py")

mx = MXNet(
entry_point=script_path,
role="SageMakerRole",
train_instance_count=1,
train_instance_type=cpu_instance_type,
sagemaker_session=sagemaker_session,
framework_version=mxnet_full_version,
)

train_input = mx.sagemaker_session.upload_data(
path=os.path.join(data_path, "train"), key_prefix="integ-test-data/mxnet_mnist/train"
)
test_input = mx.sagemaker_session.upload_data(
path=os.path.join(data_path, "test"), key_prefix="integ-test-data/mxnet_mnist/test"
)
job_name = unique_name_from_base("test-mxnet-transform")

with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
mx.fit({"train": train_input, "test": test_input}, job_name=job_name)

transform_input_path = os.path.join(data_path, "transform", "data.csv")
transform_input_key_prefix = "integ-test-data/mxnet_mnist/transform"
transform_input = mx.sagemaker_session.upload_data(
path=transform_input_path, key_prefix=transform_input_key_prefix
)

with timeout(minutes=45):
transformer = _create_transformer_and_transform_job(
mx, transform_input, cpu_instance_type, wait=True, logs=True
)

with timeout_and_delete_model_with_transformer(
transformer, sagemaker_session, minutes=TRANSFORM_DEFAULT_TIMEOUT_MINUTES
):
transformer.wait()


def _create_transformer_and_transform_job(
estimator,
transform_input,
@@ -406,6 +447,8 @@ def _create_transformer_and_transform_job(
input_filter=None,
output_filter=None,
join_source=None,
wait=False,
logs=False,
):
transformer = estimator.transformer(1, instance_type, volume_kms_key=volume_kms_key)
transformer.transform(
@@ -414,5 +457,7 @@ def _create_transformer_and_transform_job(
input_filter=input_filter,
output_filter=output_filter,
join_source=join_source,
wait=wait,
logs=logs,
)
return transformer
