Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stop Waiting For Collection Files If Training Has Ended #51

Merged
merged 4 commits into from
Nov 26, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions smdebug/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,14 @@ def __str__(self):
return "Step {} of mode {} not yet available".format(self.step, self.mode.name)


class MissingCollectionFiles(Exception):
def __init__(self):
pass

def __str__(self):
return "Training job has ended. All the collection files could not be loaded"


class IndexReaderException(Exception):
def __init__(self, message):
self.message = message
Expand Down
28 changes: 16 additions & 12 deletions smdebug/trials/trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,12 @@
match_inc,
serialize_tf_device,
)
from smdebug.exceptions import NoMoreData, StepUnavailable, TensorUnavailable
from smdebug.exceptions import (
MissingCollectionFiles,
NoMoreData,
StepUnavailable,
TensorUnavailable,
)


class Trial(ABC):
Expand Down Expand Up @@ -149,22 +154,21 @@ def _fetch():
"Waiting to read collections files generated by the training job."
)

def _wait_for_first_collection_file():
while len(collection_files) == 0:
time.sleep(2)
_fetch()

def _wait_for_all_collection_files():
while len(collection_files) < self.num_workers:
def _wait_for_collection_files(number_of_collection_file_to_wait_for):
while len(collection_files) < number_of_collection_file_to_wait_for:
time.sleep(2)
_fetch()
for collection_file in collection_files:
self.worker_set.add(get_worker_name_from_collection_file(collection_file))
if has_training_ended(self.path):
""" _fetch should have returned all the collection files if the training job has ended """
if len(collection_files) < number_of_collection_file_to_wait_for:
raise MissingCollectionFiles

_fetch()
_wait_for_first_collection_file()
_wait_for_collection_files(1) # wait for the first collection file
self._read_collections(collection_files)
_wait_for_all_collection_files()
_wait_for_collection_files(self.num_workers) # wait for all the collection files
for collection_file in collection_files:
self.worker_set.add(get_worker_name_from_collection_file(collection_file))

@abstractmethod
def _load_tensors_from_index_tensors(self, index_tensors_dict):
Expand Down
65 changes: 65 additions & 0 deletions tests/analysis/trials/test_load_collections.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Standard Library

# Third Party
import pytest

# First Party
from smdebug.exceptions import MissingCollectionFiles
from smdebug.trials import create_trial


@pytest.mark.slow
def test_load_collection_files_from_completed_job():
"""
Number of collection files : 2001
Training_has_ended.ts : Present

All the collection files have been written in the test dataset
and the training_has_ended file is present
:return:
"""
path = "s3://tornasole-testing/collection-tests/all-collection-files-present/"
try:
trial = create_trial(path)
except MissingCollectionFiles:
assert False
assert len(trial.workers()) == 2001


@pytest.mark.slow
def test_load_collection_files_from_completed_job_with_missing_files():
"""
Number of collection files : 1446
Training_has_ended.ts : Present

Some of the collection files have been removed in the test dataset.
The number of expected collection files is supposed to 2001
but the training_has_ended file is present so we stop waiting
:return:
"""
path = "s3://tornasole-testing/collection-tests/collection-files-missing/"
try:
trial = create_trial(path)
assert False
except MissingCollectionFiles:
assert True


@pytest.mark.slow
def test_load_collection_files_from_incomplete_job():
"""
Number of collection files : 2001
Training_has_ended.ts : Absent

All the collection files have been written in the test dataset
and the training_has_ended file is absent


:return:
"""
path = "s3://tornasole-testing/collection-tests/all-collection-files-present-job-incomplete/"
try:
trial = create_trial(path)
except MissingCollectionFiles:
assert False
assert len(trial.workers()) == 2001