Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit

Permalink
Make test_trainer_respects_keep_serialized_model_every_num_seconds mo…
Browse files Browse the repository at this point in the history
…re patient (#1358)
  • Loading branch information
nelson-liu authored and matt-gardner committed Jun 12, 2018
1 parent ac2e0b9 commit 6739c31
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions allennlp/tests/training/trainer_test.py
Expand Up @@ -246,15 +246,15 @@ def test_trainer_respects_num_serialized_models_to_keep(self):

def test_trainer_respects_keep_serialized_model_every_num_seconds(self):
# To test:
# Create an iterator that sleeps for 0.5 second per epoch, so the total training
# time for one epoch is slightly greater then 0.5 seconds.
# Run for 6 epochs, keeping the last 2 models, models also kept every 1 second.
# Create an iterator that sleeps for 2.5 second per epoch, so the total training
# time for one epoch is slightly greater then 2.5 seconds.
# Run for 6 epochs, keeping the last 2 models, models also kept every 5 seconds.
# Check the resulting checkpoints. Should then have models at epochs
# 2, 4, plus the last two at 5 and 6.
class WaitingIterator(BasicIterator):
# pylint: disable=arguments-differ
def _create_batches(self, *args, **kwargs):
time.sleep(0.5)
time.sleep(2.5)
return super(WaitingIterator, self)._create_batches(*args, **kwargs)

iterator = WaitingIterator(batch_size=2)
Expand All @@ -264,7 +264,7 @@ def _create_batches(self, *args, **kwargs):
iterator, self.instances, num_epochs=6,
serialization_dir=self.TEST_DIR,
num_serialized_models_to_keep=2,
keep_serialized_model_every_num_seconds=1)
keep_serialized_model_every_num_seconds=5)
trainer.train()

# Now check the serialized files
Expand Down

0 comments on commit 6739c31

Please sign in to comment.