From b941b9f8d1281982844fa3502a6f73e873113143 Mon Sep 17 00:00:00 2001 From: Pedro Larroy Date: Mon, 4 Oct 2021 12:59:05 +0200 Subject: [PATCH 1/2] Add timeout to wait_for_endpoint and _wait_until. This is useful for cases in which and endpoint takes unusually long to come online. --- src/sagemaker/session.py | 38 +++++++++++++++----- tests/unit/test_exception_on_bad_status.py | 41 +++++++++++++++++----- 2 files changed, 62 insertions(+), 17 deletions(-) diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 91b89ea4c9..8dce301600 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -3292,17 +3292,22 @@ def _check_job_status(self, job, desc, status_key_name): actual_status=status, ) - def wait_for_endpoint(self, endpoint, poll=30): + def wait_for_endpoint(self, endpoint, poll=30, timeout_seconds=1800.0): """Wait for an Amazon SageMaker endpoint deployment to complete. Args: endpoint (str): Name of the ``Endpoint`` to wait for. - poll (int): Polling interval in seconds (default: 5). + poll (float): Polling interval in seconds (default: 30). + timeout_seconds (float): Timeout in seconds (default: 1800). Returns: - dict: Return value from the ``DescribeEndpoint`` API. + dict: Return value from the ``DescribeEndpoint`` API or None if timeout_seconds passed """ - desc = _wait_until(lambda: _deploy_done(self.sagemaker_client, endpoint), poll) + desc = _wait_until( + lambda: _deploy_done(self.sagemaker_client, endpoint), poll, timeout_seconds + ) + if not desc: + return desc status = desc["EndpointStatus"] if status != "InService": @@ -4658,12 +4663,29 @@ def _wait_until_training_done(callable_fn, desc, poll=5): return job_desc -def _wait_until(callable_fn, poll=5): - """Placeholder docstring""" +def _wait_until(callable_fn, poll_seconds=5, timeout_seconds=None): + """ + Args: + callable_fn: callable to wait for which returns None to keep polling + poll_seconds (float): time to sleep between calls to callable_fn + timeout_seconds (float): Optional stop polling after timeout_seconds elapsed. + + Returns: + Result of the callable_fn + """ + waited_seconds = 0.0 + last_time = time.time() result = callable_fn() - while result is None: - time.sleep(poll) + waited_seconds += time.time() - last_time + last_time = time.time() + while result is None and timeout_seconds and waited_seconds < timeout_seconds: + sleep_s = ( + min(poll_seconds, timeout_seconds - waited_seconds) if timeout_seconds else poll_seconds + ) + time.sleep(sleep_s) result = callable_fn() + waited_seconds += time.time() - last_time + last_time = time.time() return result diff --git a/tests/unit/test_exception_on_bad_status.py b/tests/unit/test_exception_on_bad_status.py index 1eaa832125..fc9fc9d28b 100644 --- a/tests/unit/test_exception_on_bad_status.py +++ b/tests/unit/test_exception_on_bad_status.py @@ -13,8 +13,9 @@ from __future__ import absolute_import import pytest -from mock import Mock, MagicMock +from mock import Mock, MagicMock, DEFAULT import sagemaker +import time EXPANDED_ROLE = "arn:aws:iam::111111111111:role/ExpandedRole" REGION = "us-west-2" @@ -23,13 +24,23 @@ ENDPOINT_NAME = "the_point_of_end" -def get_sagemaker_session(returns_status): +def get_sagemaker_session_mock_endpoint_status(returns_status, block_seconds=None): boto_mock = MagicMock(name="boto_session", region_name=REGION) client_mock = MagicMock() client_mock.describe_model_package = MagicMock( return_value={"ModelPackageStatus": returns_status} ) - client_mock.describe_endpoint = MagicMock(return_value={"EndpointStatus": returns_status}) + side_effect = None + + def side_effect_fn(*args, **kwargs): + time.sleep(block_seconds) + return DEFAULT + + if block_seconds: + side_effect = side_effect_fn + client_mock.describe_endpoint = MagicMock( + return_value={"EndpointStatus": returns_status}, side_effect=side_effect + ) ims = sagemaker.Session(boto_session=boto_mock, sagemaker_client=client_mock) ims.expand_role = Mock(return_value=EXPANDED_ROLE) return ims @@ -37,7 +48,7 @@ def get_sagemaker_session(returns_status): def test_does_not_raise_when_successfully_created_package(): try: - sagemaker_session = get_sagemaker_session(returns_status="Completed") + sagemaker_session = get_sagemaker_session_mock_endpoint_status(returns_status="Completed") sagemaker_session.wait_for_model_package(MODEL_PACKAGE_NAME) except sagemaker.exceptions.UnexpectedStatusException: pytest.fail("UnexpectedStatusException was thrown while it should not") @@ -45,7 +56,7 @@ def test_does_not_raise_when_successfully_created_package(): def test_raise_when_failed_created_package(): try: - sagemaker_session = get_sagemaker_session(returns_status="EnRoute") + sagemaker_session = get_sagemaker_session_mock_endpoint_status(returns_status="EnRoute") sagemaker_session.wait_for_model_package(MODEL_PACKAGE_NAME) assert ( False @@ -59,7 +70,7 @@ def test_raise_when_failed_created_package(): def test_does_not_raise_when_correct_job_status(): try: job = Mock() - sagemaker_session = get_sagemaker_session(returns_status="Stopped") + sagemaker_session = get_sagemaker_session_mock_endpoint_status(returns_status="Stopped") sagemaker_session._check_job_status( job, {"TransformationJobStatus": "Stopped"}, "TransformationJobStatus" ) @@ -70,7 +81,7 @@ def test_does_not_raise_when_correct_job_status(): def test_does_raise_when_incorrect_job_status(): try: job = Mock() - sagemaker_session = get_sagemaker_session(returns_status="Failed") + sagemaker_session = get_sagemaker_session_mock_endpoint_status(returns_status="Failed") sagemaker_session._check_job_status( job, {"TransformationJobStatus": "Failed"}, "TransformationJobStatus" ) @@ -86,7 +97,7 @@ def test_does_raise_when_incorrect_job_status(): def test_does_not_raise_when_successfully_deployed_endpoint(): try: - sagemaker_session = get_sagemaker_session(returns_status="InService") + sagemaker_session = get_sagemaker_session_mock_endpoint_status(returns_status="InService") sagemaker_session.wait_for_endpoint(ENDPOINT_NAME) except sagemaker.exceptions.UnexpectedStatusException: pytest.fail("UnexpectedStatusException was thrown while it should not") @@ -94,7 +105,7 @@ def test_does_not_raise_when_successfully_deployed_endpoint(): def test_raise_when_failed_to_deploy_endpoint(): try: - sagemaker_session = get_sagemaker_session(returns_status="Failed") + sagemaker_session = get_sagemaker_session_mock_endpoint_status(returns_status="Failed") assert sagemaker_session.wait_for_endpoint(ENDPOINT_NAME) assert ( False @@ -103,3 +114,15 @@ def test_raise_when_failed_to_deploy_endpoint(): assert type(e) == sagemaker.exceptions.UnexpectedStatusException assert e.actual_status == "Failed" assert "InService" in e.allowed_statuses + + +def test_wait_for_endpoint_timeout(): + timeout_seconds = 2 + block_seconds = timeout_seconds + 3 + sagemaker_session = get_sagemaker_session_mock_endpoint_status( + returns_status="InService", block_seconds=block_seconds + ) + start_time = time.time() + sagemaker_session.wait_for_endpoint(ENDPOINT_NAME, 0.1, timeout_seconds) + elapsed_time = time.time() - start_time + assert elapsed_time >= timeout_seconds From 3a3a3d5c49f642e66fce790723739f6f3e7e3b38 Mon Sep 17 00:00:00 2001 From: Shreya Pandit Date: Fri, 29 Oct 2021 19:15:29 -0700 Subject: [PATCH 2/2] Add summary to waiter method --- src/sagemaker/session.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 8dce301600..b4c8729e4a 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -4664,14 +4664,15 @@ def _wait_until_training_done(callable_fn, desc, poll=5): def _wait_until(callable_fn, poll_seconds=5, timeout_seconds=None): - """ + """Method to allow waiting for function execution to complete. + Args: - callable_fn: callable to wait for which returns None to keep polling - poll_seconds (float): time to sleep between calls to callable_fn + callable_fn: callable to wait for which returns None to keep polling. + poll_seconds (float): time to sleep between calls to callable_fn. timeout_seconds (float): Optional stop polling after timeout_seconds elapsed. Returns: - Result of the callable_fn + Result of the callable_fn. """ waited_seconds = 0.0 last_time = time.time()