diff --git a/sagemaker-train/src/sagemaker/train/common_utils/trainer_wait.py b/sagemaker-train/src/sagemaker/train/common_utils/trainer_wait.py index eb276678e4..900f13d4c6 100644 --- a/sagemaker-train/src/sagemaker/train/common_utils/trainer_wait.py +++ b/sagemaker-train/src/sagemaker/train/common_utils/trainer_wait.py @@ -157,7 +157,7 @@ def _calculate_transition_duration(trans) -> Tuple[str, str]: def wait( training_job: TrainingJob, poll: int = 5, - timeout: Optional[int] = None + timeout: Optional[int] = 3000 ) -> None: """Wait for training job to complete with progress tracking. @@ -192,8 +192,10 @@ def wait( iteration = 0 while True: iteration += 1 - time.sleep(poll) - training_job.refresh() + time.sleep(1) + if iteration == poll: + training_job.refresh() + iteration = 0 clear_output(wait=True) status = training_job.training_job_status @@ -302,7 +304,7 @@ def wait( raise FailedStatusError(resource_type="TrainingJob", status=status, reason=failure_reason) if timeout and elapsed >= timeout: - raise TimeoutExceededError(resouce_type="TrainingJob", status=status) + raise TimeoutExceededError(resource_type="TrainingJob", status=status) else: print(f"\nTrainingJob Name: {training_job.training_job_name}") @@ -363,7 +365,7 @@ def wait( raise FailedStatusError(resource_type="TrainingJob", status=status, reason=failure_reason) if timeout and elapsed >= timeout: - raise TimeoutExceededError(resouce_type="TrainingJob", status=status) + raise TimeoutExceededError(resource_type="TrainingJob", status=status) except (FailedStatusError, TimeoutExceededError): diff --git a/sagemaker-train/src/sagemaker/train/dpo_trainer.py b/sagemaker-train/src/sagemaker/train/dpo_trainer.py index 66ca88130b..1680d92450 100644 --- a/sagemaker-train/src/sagemaker/train/dpo_trainer.py +++ b/sagemaker-train/src/sagemaker/train/dpo_trainer.py @@ -261,7 +261,11 @@ def train(self, if wait: from sagemaker.train.common_utils.trainer_wait import wait as _wait - _wait(training_job) + from sagemaker.core.utils.exceptions import TimeoutExceededError + try : + _wait(training_job) + except TimeoutExceededError as e: + logger.error("Error: %s", e) self.latest_training_job = training_job return training_job diff --git a/sagemaker-train/src/sagemaker/train/rlaif_trainer.py b/sagemaker-train/src/sagemaker/train/rlaif_trainer.py index 68d50a2989..230d2566d0 100644 --- a/sagemaker-train/src/sagemaker/train/rlaif_trainer.py +++ b/sagemaker-train/src/sagemaker/train/rlaif_trainer.py @@ -268,7 +268,11 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati if wait: from sagemaker.train.common_utils.trainer_wait import wait as _wait - _wait(training_job) + from sagemaker.core.utils.exceptions import TimeoutExceededError + try : + _wait(training_job) + except TimeoutExceededError as e: + logger.error("Error: %s", e) self.latest_training_job = training_job return training_job diff --git a/sagemaker-train/src/sagemaker/train/rlvr_trainer.py b/sagemaker-train/src/sagemaker/train/rlvr_trainer.py index e14734b692..4274723f5a 100644 --- a/sagemaker-train/src/sagemaker/train/rlvr_trainer.py +++ b/sagemaker-train/src/sagemaker/train/rlvr_trainer.py @@ -268,7 +268,11 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, if wait: from sagemaker.train.common_utils.trainer_wait import wait as _wait - _wait(training_job) + from sagemaker.core.utils.exceptions import TimeoutExceededError + try: + _wait(training_job) + except TimeoutExceededError as e: + logger.error("Error: %s", e) self.latest_training_job = training_job return training_job diff --git a/sagemaker-train/src/sagemaker/train/sft_trainer.py b/sagemaker-train/src/sagemaker/train/sft_trainer.py index 4e109a85b9..17c4ec344d 100644 --- a/sagemaker-train/src/sagemaker/train/sft_trainer.py +++ b/sagemaker-train/src/sagemaker/train/sft_trainer.py @@ -1,3 +1,4 @@ +from logging import exception from typing import Optional, Union import logging from sagemaker.train.base_trainer import BaseTrainer @@ -261,7 +262,11 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati if wait: from sagemaker.train.common_utils.trainer_wait import wait as _wait - _wait(training_job) + from sagemaker.core.utils.exceptions import TimeoutExceededError + try : + _wait(training_job) + except TimeoutExceededError as e: + logger.error("Error: %s", e) self.latest_training_job = training_job return training_job