Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions sagemaker-train/src/sagemaker/train/common_utils/trainer_wait.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def _calculate_transition_duration(trans) -> Tuple[str, str]:
def wait(
training_job: TrainingJob,
poll: int = 5,
timeout: Optional[int] = None
timeout: Optional[int] = 3000
) -> None:
"""Wait for training job to complete with progress tracking.

Expand Down Expand Up @@ -192,8 +192,10 @@ def wait(
iteration = 0
while True:
iteration += 1
time.sleep(poll)
training_job.refresh()
time.sleep(1)
if iteration == poll:
training_job.refresh()
iteration = 0
clear_output(wait=True)

status = training_job.training_job_status
Expand Down Expand Up @@ -302,7 +304,7 @@ def wait(
raise FailedStatusError(resource_type="TrainingJob", status=status, reason=failure_reason)

if timeout and elapsed >= timeout:
raise TimeoutExceededError(resouce_type="TrainingJob", status=status)
raise TimeoutExceededError(resource_type="TrainingJob", status=status)

else:
print(f"\nTrainingJob Name: {training_job.training_job_name}")
Expand Down Expand Up @@ -363,7 +365,7 @@ def wait(
raise FailedStatusError(resource_type="TrainingJob", status=status, reason=failure_reason)

if timeout and elapsed >= timeout:
raise TimeoutExceededError(resouce_type="TrainingJob", status=status)
raise TimeoutExceededError(resource_type="TrainingJob", status=status)


except (FailedStatusError, TimeoutExceededError):
Expand Down
6 changes: 5 additions & 1 deletion sagemaker-train/src/sagemaker/train/dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,11 @@ def train(self,

if wait:
from sagemaker.train.common_utils.trainer_wait import wait as _wait
_wait(training_job)
from sagemaker.core.utils.exceptions import TimeoutExceededError
try :
_wait(training_job)
except TimeoutExceededError as e:
logger.error("Error: %s", e)

self.latest_training_job = training_job
return training_job
Expand Down
6 changes: 5 additions & 1 deletion sagemaker-train/src/sagemaker/train/rlaif_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,11 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati

if wait:
from sagemaker.train.common_utils.trainer_wait import wait as _wait
_wait(training_job)
from sagemaker.core.utils.exceptions import TimeoutExceededError
try :
_wait(training_job)
except TimeoutExceededError as e:
logger.error("Error: %s", e)

self.latest_training_job = training_job
return training_job
Expand Down
6 changes: 5 additions & 1 deletion sagemaker-train/src/sagemaker/train/rlvr_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,11 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None,

if wait:
from sagemaker.train.common_utils.trainer_wait import wait as _wait
_wait(training_job)
from sagemaker.core.utils.exceptions import TimeoutExceededError
try:
_wait(training_job)
except TimeoutExceededError as e:
logger.error("Error: %s", e)

self.latest_training_job = training_job
return training_job
7 changes: 6 additions & 1 deletion sagemaker-train/src/sagemaker/train/sft_trainer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from logging import exception
from typing import Optional, Union
import logging
from sagemaker.train.base_trainer import BaseTrainer
Expand Down Expand Up @@ -261,7 +262,11 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati

if wait:
from sagemaker.train.common_utils.trainer_wait import wait as _wait
_wait(training_job)
from sagemaker.core.utils.exceptions import TimeoutExceededError
try :
_wait(training_job)
except TimeoutExceededError as e:
logger.error("Error: %s", e)

self.latest_training_job = training_job
return training_job
Expand Down
Loading