Skip to content

Commit

Permalink
Return a summary of the training from the train method (#724)
Browse files Browse the repository at this point in the history
As part of this change, the training statistics management is moved to
a separate class so that the training loop does not become more complex.
  • Loading branch information
guillaumekln committed Oct 14, 2020
1 parent 36524df commit 34a5f52
Show file tree
Hide file tree
Showing 4 changed files with 254 additions and 101 deletions.
28 changes: 20 additions & 8 deletions opennmt/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,17 +140,24 @@ def _init_model(self, config):
model.initialize(config["data"], params=config["params"])
return model

def train(self, num_devices=1, with_eval=False, checkpoint_path=None, hvd=None):
def train(self,
num_devices=1,
with_eval=False,
checkpoint_path=None,
hvd=None,
return_summary=False):
"""Runs the training loop.
Args:
num_devices: Number of devices to use for training.
with_eval: Enable evaluation during training.
checkpoint_path: The checkpoint path to load the model weights from.
hvd: Optional Horovod module.
return_summary: Return a summary of the training from this function.
Returns:
The path to the final model directory.
The path to the final model directory and, if :obj:`return_summary` is set,
a dictionary with various training statistics.
"""
if hvd is None:
num_replicas = num_devices
Expand Down Expand Up @@ -227,7 +234,7 @@ def train(self, num_devices=1, with_eval=False, checkpoint_path=None, hvd=None):
else:
trainer = training_util.Trainer(model, optimizer, checkpoint=checkpoint)

trainer(
summary = trainer(
dataset_fn,
max_step=train_config.get("max_step"),
accum_steps=accum_steps,
Expand All @@ -237,14 +244,19 @@ def train(self, num_devices=1, with_eval=False, checkpoint_path=None, hvd=None):
eval_steps=eval_config.get("steps", 5000),
moving_average_decay=train_config.get("moving_average_decay"))

if checkpoint is None:
return None
average_last_checkpoints = train_config.get("average_last_checkpoints", 0)
if average_last_checkpoints > 0:
return self.average_checkpoints(
if checkpoint is None:
output_dir = None
elif average_last_checkpoints > 0:
output_dir = self.average_checkpoints(
os.path.join(checkpoint.model_dir, "avg"),
max_count=average_last_checkpoints)
return checkpoint.model_dir
else:
output_dir = checkpoint.model_dir

if return_summary:
return output_dir, summary
return output_dir

def evaluate(self, features_file=None, labels_file=None, checkpoint_path=None):
"""Runs evaluation.
Expand Down
3 changes: 2 additions & 1 deletion opennmt/tests/runner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,9 @@ def testTrain(self, pass_model_builder):
}
}
runner = self._getTransliterationRunner(config, pass_model_builder=pass_model_builder)
avg_dir = runner.train()
avg_dir, summary = runner.train(return_summary=True)
self.assertEqual(runner.model_dir, avg_dir)
self.assertIsInstance(summary, dict)
self.assertEndsWith(tf.train.latest_checkpoint(avg_dir), "145002")
self.assertLen(tf.train.get_checkpoint_state(avg_dir).all_model_checkpoint_paths, 1)
model_dir = os.path.dirname(avg_dir)
Expand Down
46 changes: 46 additions & 0 deletions opennmt/tests/training_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,52 @@ def testEmptyTrainingDataset(self):
with self.assertRaisesRegex(RuntimeError, "No training steps"):
trainer(dataset)

def testTrainingStats(self):
  """Checks the per-step and global statistics reported by TrainingStats."""
  model = _make_seq2seq_model(self.get_temp_dir())
  optimizer = tf.keras.optimizers.SGD(1.0)
  stats = training.TrainingStats(model, optimizer, warmup_steps=2)

  def _record(source_length, target_length, step, loss):
    # Simulate seeing one example followed by one optimization step.
    stats.update_on_example({"length": source_length}, {"length": target_length})
    stats.update_on_step(step, loss)

  def _throughput_values(summary):
    # Gather the 3 throughput entries of a summary in a fixed order.
    return (summary["steps_per_sec"],
            summary["words_per_sec"]["source"],
            summary["words_per_sec"]["target"])

  _record(24, 23, 5, 9.8)
  _record(10, 8, 10, 9.6)

  last = stats.get_last_summary()
  self.assertEqual(last["learning_rate"], 1.0)
  self.assertEqual(last["step"], 10)
  self.assertEqual(last["loss"], 9.6)

  # Within the warmup window (first 2 steps), throughput is reported as 0.
  for value in _throughput_values(last):
    self.assertEqual(value, 0)

  _record(14, 21, 15, 9.4)

  # Past the warmup window, throughput values become meaningful.
  for value in _throughput_values(stats.get_last_summary()):
    self.assertNotEqual(value, 0)

  stats.log()

  # Logging resets the accumulated throughput values.
  for value in _throughput_values(stats.get_last_summary()):
    self.assertEqual(value, 0)

  global_summary = stats.get_global_summary()
  self.assertEqual(global_summary["last_learning_rate"], 1.0)
  self.assertEqual(global_summary["last_step"], 15)
  self.assertEqual(global_summary["last_loss"], 9.4)
  self.assertEqual(global_summary["average_loss"], 9.6)
  self.assertEqual(global_summary["num_steps"], 3)


if __name__ == "__main__":
tf.test.main()

0 comments on commit 34a5f52

Please sign in to comment.