Skip to content

Commit

Permalink
[AIRFLOW-6884] Make SageMakerTrainingOperator idempotent (#7598)
Browse files Browse the repository at this point in the history
  • Loading branch information
BasPH committed Mar 12, 2020
1 parent 2327aa5 commit c0c5f11
Show file tree
Hide file tree
Showing 3 changed files with 146 additions and 3 deletions.
85 changes: 85 additions & 0 deletions airflow/providers/amazon/aws/hooks/sagemaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import tempfile
import time
import warnings
from functools import partial
from typing import Dict, List, Optional

from botocore.exceptions import ClientError

Expand Down Expand Up @@ -742,3 +744,86 @@ def check_training_status_with_log(self, job_name, non_terminal_states, failed_s
billable_time = (last_description['TrainingEndTime'] - last_description['TrainingStartTime']) \
* instance_count
self.log.info('Billable seconds: %d', int(billable_time.total_seconds()) + 1)

def list_training_jobs(
self, name_contains: Optional[str] = None, max_results: Optional[int] = None, **kwargs
) -> List[Dict]:
"""
This method wraps boto3's list_training_jobs(). The training job name and max results are configurable
via arguments. Other arguments are not, and should be provided via kwargs. Note boto3 expects these in
CamelCase format, for example:
.. code-block:: python
list_training_jobs(name_contains="myjob", StatusEquals="Failed")
.. seealso::
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.list_training_jobs
:param name_contains: (optional) partial name to match
:param max_results: (optional) maximum number of results to return. None returns infinite results
:param kwargs: (optional) kwargs to boto3's list_training_jobs method
:return: results of the list_training_jobs request
"""

config = dict()

if name_contains:
if "NameContains" in kwargs:
raise AirflowException("Either name_contains or NameContains can be provided, not both.")
config["NameContains"] = name_contains

if "MaxResults" in kwargs and kwargs["MaxResults"] is not None:
if max_results:
raise AirflowException("Either max_results or MaxResults can be provided, not both.")
# Unset MaxResults, we'll use the SageMakerHook's internal method for iteratively fetching results
max_results = kwargs["MaxResults"]
del kwargs["MaxResults"]

config.update(kwargs)
list_training_jobs_request = partial(self.get_conn().list_training_jobs, **config)
results = self._list_request(
list_training_jobs_request, "TrainingJobSummaries", max_results=max_results
)
return results

def _list_request(self, partial_func, result_key: str, max_results: Optional[int] = None) -> List[Dict]:
"""
All AWS boto3 list_* requests return results in batches (if the key "NextToken" is contained in the
result, there are more results to fetch). The default AWS batch size is 10, and configurable up to
100. This function iteratively loads all results (or up to a given maximum).
Each boto3 list_* function returns the results in a list with a different name. The key of this
structure must be given to iterate over the results, e.g. "TransformJobSummaries" for
list_transform_jobs().
:param partial_func: boto3 function with arguments
:param result_key: the result key to iterate over
:param max_results: maximum number of results to return (None = infinite)
:return: Results of the list_* request
"""

sagemaker_max_results = 100 # Fixed number set by AWS

results: List[Dict] = []
next_token = None

while True:
kwargs = dict()
if next_token is not None:
kwargs["NextToken"] = next_token

if max_results is None:
kwargs["MaxResults"] = sagemaker_max_results
else:
kwargs["MaxResults"] = min(max_results - len(results), sagemaker_max_results)

response = partial_func(**kwargs)
self.log.debug("Fetched %s results.", len(response[result_key]))
results.extend(response[result_key])

if "NextToken" not in response or (max_results is not None and len(results) == max_results):
# Return when there are no results left (no NextToken) or when we've reached max_results.
return results
else:
next_token = response["NextToken"]
31 changes: 28 additions & 3 deletions airflow/providers/amazon/aws/operators/sagemaker_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ class SageMakerTrainingOperator(SageMakerBaseOperator):
doesn't finish within max_ingestion_time seconds. If you set this parameter to None,
the operation does not timeout.
:type max_ingestion_time: int
:param action_if_job_exists: Behaviour if the job name already exists. Possible options are "increment"
(default) and "fail".
:type action_if_job_exists: str
"""

integer_fields = [
Expand All @@ -61,15 +64,23 @@ def __init__(self,
print_log=True,
check_interval=30,
max_ingestion_time=None,
action_if_job_exists: str = "increment", # TODO use typing.Literal for this in Python 3.8
*args, **kwargs):
super().__init__(config=config,
*args, **kwargs)
super().__init__(config=config, *args, **kwargs)

self.wait_for_completion = wait_for_completion
self.print_log = print_log
self.check_interval = check_interval
self.max_ingestion_time = max_ingestion_time

if action_if_job_exists in ("increment", "fail"):
self.action_if_job_exists = action_if_job_exists
else:
raise AirflowException(
"Argument action_if_job_exists accepts only 'increment' and 'fail'. "
f"Provided value: '{action_if_job_exists}'."
)

def expand_role(self):
if 'RoleArn' in self.config:
hook = AwsBaseHook(self.aws_conn_id)
Expand All @@ -78,8 +89,22 @@ def expand_role(self):
def execute(self, context):
self.preprocess_config()

self.log.info('Creating SageMaker Training Job %s.', self.config['TrainingJobName'])
training_job_name = self.config["TrainingJobName"]
training_jobs = self.hook.list_training_jobs(name_contains=training_job_name)

# Check if given TrainingJobName already exists
if training_job_name in [tj["TrainingJobName"] for tj in training_jobs]:
if self.action_if_job_exists == "increment":
self.log.info("Found existing training job with name '%s'.", training_job_name)
new_training_job_name = f"{training_job_name}-{len(training_jobs) + 1}"
self.config["TrainingJobName"] = new_training_job_name
self.log.info("Incremented training job name to '%s'.", new_training_job_name)
elif self.action_if_job_exists == "fail":
raise AirflowException(
f"A SageMaker training job with name {training_job_name} already exists."
)

self.log.info("Creating SageMaker training job %s.", self.config["TrainingJobName"])
response = self.hook.create_training_job(
self.config,
wait_for_completion=self.wait_for_completion,
Expand Down
33 changes: 33 additions & 0 deletions tests/providers/amazon/aws/operators/test_sagemaker_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,39 @@ def test_execute_with_failure(self, mock_training, mock_client):
self.assertRaises(AirflowException, self.sagemaker.execute, None)
# pylint: enable=unused-argument

@mock.patch.object(SageMakerHook, "get_conn")
@mock.patch.object(SageMakerHook, "list_training_jobs")
@mock.patch.object(SageMakerHook, "create_training_job")
def test_execute_with_existing_job_increment(
self, mock_create_training_job, mock_list_training_jobs, mock_client
):
self.sagemaker.action_if_job_exists = "increment"
mock_create_training_job.return_value = {"ResponseMetadata": {"HTTPStatusCode": 200}}
mock_list_training_jobs.return_value = [{"TrainingJobName": job_name}]
self.sagemaker.execute(None)

expected_config = create_training_params.copy()
# Expect to see TrainingJobName suffixed with "-2" because we return one existing job
expected_config["TrainingJobName"] = f"{job_name}-2"
mock_create_training_job.assert_called_once_with(
expected_config,
wait_for_completion=False,
print_log=True,
check_interval=5,
max_ingestion_time=None,
)

@mock.patch.object(SageMakerHook, "get_conn")
@mock.patch.object(SageMakerHook, "list_training_jobs")
@mock.patch.object(SageMakerHook, "create_training_job")
def test_execute_with_existing_job_fail(
self, mock_create_training_job, mock_list_training_jobs, mock_client
):
self.sagemaker.action_if_job_exists = "fail"
mock_create_training_job.return_value = {"ResponseMetadata": {"HTTPStatusCode": 200}}
mock_list_training_jobs.return_value = [{"TrainingJobName": job_name}]
self.assertRaises(AirflowException, self.sagemaker.execute, None)


if __name__ == '__main__':
unittest.main()

0 comments on commit c0c5f11

Please sign in to comment.