diff --git a/opennmt/tests/runner_test.py b/opennmt/tests/runner_test.py
index b826161ea..87ecbec49 100644
--- a/opennmt/tests/runner_test.py
+++ b/opennmt/tests/runner_test.py
@@ -115,6 +115,10 @@ def testTrain(self, pass_model_builder):
     with open(en_file) as f:
       self.assertEqual(next(f).strip(), "a t z m o n")
 
+    # Continue the training without updating max_step
+    with self.assertRaisesRegex(RuntimeError, "max_step"):
+      runner.train()
+
   @test_util.run_with_two_cpu_devices
   def testTrainDistribute(self):
     ar_file, en_file = self._makeTransliterationData()
diff --git a/opennmt/training.py b/opennmt/training.py
index f4b870b20..a685874c1 100644
--- a/opennmt/training.py
+++ b/opennmt/training.py
@@ -73,11 +73,13 @@ def __call__(self,
         https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage.
     """
     if max_step is not None and self._optimizer.iterations.numpy() >= max_step:
-      tf.get_logger().warning("Model already reached max_step = %d. Exiting.", max_step)
-      return
+      raise RuntimeError("The training already reached max_step (%d). If you "
+                         "want to continue the training, you should increase the "
+                         "max_step value in the training parameters." % max_step)
     if evaluator is not None and evaluator.should_stop():
-      tf.get_logger().warning("Early stopping conditions are already met. Exiting.")
-      return
+      raise RuntimeError("The early stopping conditions are already met. If you "
+                         "want to continue the training, you should update your "
+                         "early stopping parameters.")
 
     self._gradient_accumulator.reset()
     self._words_counters.clear()
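
Reviewer note: with this change, calling `runner.train()` on a finished run raises a RuntimeError instead of silently returning. A minimal sketch of how calling code could react (purely illustrative, not part of this patch; the `runner` object is assumed to be configured as in `testTrain` above):

    try:
        runner.train()
    except RuntimeError as error:
        # Raised when max_step was already reached or the early stopping
        # conditions are already met; increase max_step (or relax early
        # stopping) in the training parameters before calling train() again.
        print("Training did not run:", error)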