Skip to content

Commit

Permalink
change: refactor tests to use common retry method (#1001)
Browse files Browse the repository at this point in the history
  • Loading branch information
laurenyu committed Aug 26, 2019
1 parent 4f00559 commit 0d74efb
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 40 deletions.
18 changes: 1 addition & 17 deletions tests/integ/file_system_input_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@
from os import path
import stat
import tempfile
import time
import uuid

from botocore.exceptions import ClientError
from fabric import Connection

from tests.integ.retry import retries
from tests.integ.vpc_test_utils import check_or_create_vpc_resources_efs_fsx

VPC_NAME = "sagemaker-efs-fsx-vpc"
Expand All @@ -36,7 +36,6 @@
AMI_ID = "ami-082b5a644766e0e6f"
MIN_COUNT = 1
MAX_COUNT = 1
TIME_SLEEP_DURATION = 10

RESOURCE_PATH = os.path.join(os.path.dirname(__file__), "..", "data")
MNIST_RESOURCE_PATH = os.path.join(RESOURCE_PATH, "tensorflow_mnist")
Expand Down Expand Up @@ -307,21 +306,6 @@ def _instance_profile_exists(sagemaker_session):
return True


def retries(max_retry_count, exception_message_prefix):
current_retry_count = 0
while current_retry_count <= max_retry_count:
yield current_retry_count

current_retry_count += 1
time.sleep(TIME_SLEEP_DURATION)

raise Exception(
"{} has reached the maximum retry count {}".format(
exception_message_prefix, max_retry_count
)
)


def tear_down(sagemaker_session, fs_resources):
fsx_client = sagemaker_session.boto_session.client("fsx")
file_system_fsx_id = fs_resources.file_system_fsx_id
Expand Down
29 changes: 29 additions & 0 deletions tests/integ/retry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import time

DEFAULT_SLEEP_TIME_SECONDS = 10


def retries(max_retry_count, exception_message_prefix, seconds_to_sleep=DEFAULT_SLEEP_TIME_SECONDS):
for i in range(max_retry_count):
yield i
time.sleep(seconds_to_sleep)

raise Exception(
"{} has reached the maximum retry count {}".format(
exception_message_prefix, max_retry_count
)
)
13 changes: 4 additions & 9 deletions tests/integ/test_inference_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# Copyright 2017-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
Expand All @@ -14,7 +14,6 @@

import json
import os
import time

import pytest
from tests.integ import DATA_DIR, TRANSFORM_DEFAULT_TIMEOUT_MINUTES
Expand All @@ -30,6 +29,7 @@
from sagemaker.predictor import RealTimePredictor, json_serializer
from sagemaker.sparkml.model import SparkMLModel
from sagemaker.utils import sagemaker_timestamp
from tests.integ.retry import retries

SPARKML_DATA_PATH = os.path.join(DATA_DIR, "sparkml_model")
XGBOOST_DATA_PATH = os.path.join(DATA_DIR, "xgboost_model")
Expand Down Expand Up @@ -190,16 +190,11 @@ def test_inference_pipeline_model_deploy_with_update_endpoint(
model.deploy(1, cpu_instance_type, update_endpoint=True, endpoint_name=endpoint_name)

# Wait for endpoint to finish updating
max_retry_count = 40 # Endpoint update takes ~7min. 40 retries * 30s sleeps = 20min timeout
current_retry_count = 0
while current_retry_count <= max_retry_count:
if current_retry_count >= max_retry_count:
raise Exception("Endpoint status not 'InService' within expected timeout.")
time.sleep(30)
# Endpoint update takes ~7min. 40 retries * 30s sleeps = 20min timeout
for _ in retries(40, "Waiting for 'InService' endpoint status", seconds_to_sleep=30):
new_endpoint = sagemaker_session.sagemaker_client.describe_endpoint(
EndpointName=endpoint_name
)
current_retry_count += 1
if new_endpoint["EndpointStatus"] == "InService":
break

Expand Down
12 changes: 4 additions & 8 deletions tests/integ/test_mxnet_train.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# Copyright 2017-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
Expand All @@ -24,6 +24,7 @@
from sagemaker.utils import sagemaker_timestamp
from tests.integ import DATA_DIR, PYTHON_VERSION, TRAINING_DEFAULT_TIMEOUT_MINUTES
from tests.integ.kms_utils import get_or_create_kms_key
from tests.integ.retry import retries
from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name


Expand Down Expand Up @@ -182,16 +183,11 @@ def test_deploy_model_with_update_endpoint(
model.deploy(1, cpu_instance_type, update_endpoint=True, endpoint_name=endpoint_name)

# Wait for endpoint to finish updating
max_retry_count = 40 # Endpoint update takes ~7min. 40 retries * 30s sleeps = 20min timeout
current_retry_count = 0
while current_retry_count <= max_retry_count:
if current_retry_count >= max_retry_count:
raise Exception("Endpoint status not 'InService' within expected timeout.")
time.sleep(30)
# Endpoint update takes ~7min. 40 retries * 30s sleeps = 20min timeout
for _ in retries(40, "Waiting for 'InService' endpoint status", seconds_to_sleep=30):
new_endpoint = sagemaker_session.sagemaker_client.describe_endpoint(
EndpointName=endpoint_name
)
current_retry_count += 1
if new_endpoint["EndpointStatus"] == "InService":
break

Expand Down
11 changes: 5 additions & 6 deletions tests/integ/test_tf_script_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import tests.integ
from tests.integ import timeout
from tests.integ.retry import retries
from tests.integ.s3_utils import assert_s3_files_exist

ROLE = "SageMakerRole"
Expand Down Expand Up @@ -199,15 +200,13 @@ def test_deploy_with_input_handlers(sagemaker_session, instance_type):
assert expected_result == result


def _assert_tags_match(sagemaker_client, resource_arn, tags, retries=15):
actual_tags = None
for _ in range(retries):
def _assert_tags_match(sagemaker_client, resource_arn, tags, retry_count=15):
# endpoint and training tags might take minutes to propagate.
for _ in retries(retry_count, "Getting endpoint tags", seconds_to_sleep=30):
actual_tags = sagemaker_client.list_tags(ResourceArn=resource_arn)["Tags"]
if actual_tags:
break
else:
# endpoint and training tags might take minutes to propagate. Sleeping.
time.sleep(30)

assert actual_tags == tags


Expand Down

0 comments on commit 0d74efb

Please sign in to comment.