Define base Trainer class (#592)
guillaumekln committed Jan 16, 2020
1 parent e383c0d · commit d7db4b1
Showing 2 changed files with 163 additions and 64 deletions.
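
At a glance, this commit splits the previously monolithic Trainer into an abstract base class that owns the generic training loop (status reports, checkpoints, evaluation) and concrete subclasses that decide how each training step runs. A rough sketch of the resulting hierarchy, simplified from the diff below (not the actual source):

import abc

class Trainer(abc.ABC):
    """Owns the training loop: reporting, checkpointing, evaluation."""

    @abc.abstractmethod
    def _steps(self, dataset, accum_steps=1, report_steps=None):
        """Yields one reported loss per parameter update."""

class BasicTrainer(Trainer):
    """Single-device loop; no gradient accumulation."""

class DistributionStrategyTrainer(Trainer):
    """tf.distribute.MirroredStrategy replicas plus gradient accumulation."""
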
5 changes: 2 additions & 3 deletions opennmt/runner.py
@@ -193,10 +193,9 @@ def train(self, num_devices=1, with_eval=False, checkpoint_path=None):
else:
accum_steps = 1

trainer = training_util.Trainer(
trainer = training_util.DistributionStrategyTrainer(
checkpoint,
devices=misc.get_devices(count=num_devices),
mixed_precision=self._mixed_precision)
devices=misc.get_devices(count=num_devices))
trainer(
dataset_fn,
max_step=train_config.get("max_step"),
222 changes: 161 additions & 61 deletions opennmt/training.py
@@ -1,5 +1,6 @@
"""Training related classes and functions."""

import abc
import os
import time

@@ -9,39 +10,31 @@
from opennmt.utils import misc


class Trainer(object):
"""Model trainer."""
class Trainer(abc.ABC):
"""Base class for model trainer."""

def __init__(self, checkpoint, devices=None, mixed_precision=False):
def __init__(self, checkpoint, is_master=True):
"""Initializes the trainer.
Args:
checkpoint: A :class:`opennmt.utils.checkpoint.Checkpoint` instance.
devices: List of device strings to use for training.
mixed_precision: Whether mixed precision is enabled or not.
checkpoint: A :class:`opennmt.utils.Checkpoint` instance.
is_master: Whether this trainer instance is the master trainer.
"""
if not devices:
devices = misc.get_devices(count=1) # Train with 1 device by default.
self._checkpoint = checkpoint
self._mixed_precision = mixed_precision
self._is_master = is_master
self._model = checkpoint.model
self._strategy = tf.distribute.MirroredStrategy(devices=devices)
self._summary_writer = tf.summary.create_file_writer(checkpoint.model_dir)

optimizer = checkpoint.optimizer
if optimizer is None:
raise ValueError("No optimizer is defined")
if mixed_precision:
graph_optimizer_options = tf.config.optimizer.get_experimental_options()
mixed_precision_enabled = graph_optimizer_options.get("auto_mixed_precision")
if (mixed_precision_enabled
and not isinstance(optimizer, tf.keras.mixed_precision.experimental.LossScaleOptimizer)):
optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(optimizer, "dynamic")
self._optimizer = optimizer

self._words_counters = {}
with self._strategy.scope():
# Create some variables under the strategy scope.
_ = self._optimizer.iterations
self._model.create_variables()
self._gradient_accumulator = optimizer_util.GradientAccumulator()

def __call__(self,
dataset,
max_step=None,
@@ -73,28 +66,28 @@ def __call__(self,
tf.get_logger().warning("Early stopping conditions are already met. Exiting.")
return

self._gradient_accumulator.reset()
self._words_counters.clear()

last_report_time = time.time()
last_step = 0

with self._summary_writer.as_default():
if self._optimizer.iterations.numpy() == 0:
self._checkpoint.save(0)
self._model.visualize(self._checkpoint.model_dir)

for step, loss in self._steps(dataset, accum_steps=accum_steps, report_steps=report_steps):
last_step = step
iterations = self._optimizer.iterations
tf.summary.experimental.set_step(iterations)

last_report_step = 0
last_report_time = time.time()
for loss in self._steps(dataset, accum_steps=accum_steps, report_steps=report_steps):
if tf.math.is_nan(loss):
raise RuntimeError("Model diverged with loss = NaN.")
step = iterations.numpy()
if step % report_steps == 0:
last_report_time = _report_training_status(
_report_training_status(
step,
loss,
self._optimizer.learning_rate,
self._synchronize_words_counters(),
self._get_words_counters(),
last_report_step,
last_report_time)
if save_steps is not None and step % save_steps == 0:
self._checkpoint.save(step)
last_report_step = step
last_report_time = time.time()
if step == 1 or (save_steps is not None and step % save_steps == 0):
self._save_checkpoint(step)
if evaluator is not None and eval_steps is not None and step % eval_steps == 0:
self._evaluate(evaluator, step, export_on_best=export_on_best)
if evaluator.should_stop():
@@ -103,26 +96,137 @@ def __call__(self,
if step == max_step:
break

if evaluator is not None and last_step != evaluator.last_evaluated_step:
self._evaluate(evaluator, last_step, export_on_best=export_on_best)
self._checkpoint.save(last_step)
if evaluator is not None and step != evaluator.last_evaluated_step:
self._evaluate(evaluator, step, export_on_best=export_on_best)
self._save_checkpoint(step)
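
Traced for small illustrative settings, the loop above produces this cadence:

# Illustrative cadence for report_steps=100, save_steps=1000, eval_steps=2000:
#   step 1                -> checkpoint (the new "step == 1" early save)
#   steps 100, 200, ...   -> status log and summaries
#   steps 1000, 2000, ... -> checkpoint
#   steps 2000, 4000, ... -> evaluation (and possibly a SavedModel export)
# plus a final evaluation and checkpoint at the last step if not already done.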

@abc.abstractmethod
def _steps(self, dataset, accum_steps=1, report_steps=None):
"""Returns a generator over training steps (i.e. parameters update).
Args:
dataset: The training dataset.
accum_steps: Accumulate the gradients of this many steps/batches.
report_steps: Report summary statistics every this many steps. This should
typically be used in a ``tf.summary.record_if`` context.
Returns:
A generator that yields a loss value to report for each step.
"""
raise NotImplementedError()
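
To illustrate the contract, a minimal hypothetical subclass (not part of this commit): _steps performs the parameter updates itself and yields only the loss to report.

class ToyTrainer(Trainer):

    def _steps(self, dataset, accum_steps=1, report_steps=None):
        # accum_steps and report_steps are ignored here for brevity.
        for source, target in dataset:
            with tf.GradientTape() as tape:
                training_loss, reported_loss = self._run_model(source, target)
            variables = self._model.trainable_variables
            gradients = tape.gradient(training_loss, variables)
            # apply_gradients also increments self._optimizer.iterations,
            # which the base __call__ uses as the step counter.
            self._optimizer.apply_gradients(list(zip(gradients, variables)))
            yield reported_loss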

def _get_words_counters(self):
"""Returns the accumulated words counters and resets them.
This is used to report the words per second in the training logs.
Returns:
A dictionary mapping a counter name to a Python value.
"""
return {}

def _run_model(self, source, target):
"""Computes the loss of the given source and target pair.
Args:
source: A nested structure of tensors.
target: A nested structure of tensors.
Returns:
A tuple containing:
- The loss to compute the gradients.
- The loss to report.
"""
first_call = not self._model.built
outputs, _ = self._model(
source,
labels=target,
training=True,
step=self._optimizer.iterations)
loss = self._model.compute_loss(outputs, target, training=True)
if isinstance(loss, tuple):
training_loss = loss[0] / loss[1]
reported_loss = loss[0] / loss[2] if len(loss) > 2 else training_loss
else:
training_loss, reported_loss = loss, loss
training_loss = self._model.regularize_loss(
training_loss, variables=self._model.trainable_variables)
if first_call and self._is_master:
self._model.visualize(self._checkpoint.model_dir)
return training_loss, reported_loss
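
A worked example of the tuple convention above (values are illustrative; what each normalizer means depends on the model):

# Suppose compute_loss returns (summed_loss, training_norm, reporting_norm),
# e.g. (1200.0, 32.0, 400.0) for a batch of 32 examples and 400 target tokens:
#   training_loss = 1200.0 / 32.0   # = 37.5, used to compute the gradients
#   reported_loss = 1200.0 / 400.0  # = 3.0, shown as the loss in the logs
# With a 2-tuple both losses share the same normalizer; with a plain tensor
# they are identical.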

def _save_checkpoint(self, step):
"""Saves a checkpoint for step."""
if not self._is_master:
return
self._checkpoint.save(step)

def _evaluate(self, evaluator, step, export_on_best=None):
"""Runs evaluation for step."""
if not self._is_master:
return
metrics = evaluator(step)
if export_on_best is not None and evaluator.is_best(export_on_best):
export_dir = os.path.join(self._checkpoint.model_dir, "export", str(step))
tf.get_logger().info("Exporting SavedModel to %s (best %s so far: %f)",
export_dir, export_on_best, metrics[export_on_best])
self._model.export(export_dir)
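
A hypothetical invocation tying evaluation and export together. The __call__ signature is truncated above, but its body references these arguments; the values and the metric name are illustrative:

trainer(
    dataset_fn,
    max_step=500000,
    save_steps=5000,
    evaluator=evaluator,        # periodic evaluation driver
    eval_steps=5000,
    export_on_best="bleu")      # export a SavedModel whenever BLEU improves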


class BasicTrainer(Trainer):
"""Basic single GPU trainer."""

def _steps(self, dataset, accum_steps=1, report_steps=None):
if accum_steps != 1:
raise ValueError("BasicTrainer does not support gradient accumulation")
if callable(dataset):
dataset = dataset(tf.distribute.InputContext())

@tf.function(input_signature=dataset.element_spec)
def _step(source, target):
training_loss, reported_loss = self._run_model(source, target)
variables = self._model.trainable_variables
gradients = self._optimizer.get_gradients(training_loss, variables)
self._optimizer.apply_gradients(list(zip(gradients, variables)))
return reported_loss

for source, target in dataset:
yield _step(source, target)
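
A hypothetical single-device usage, following the naming in runner.py above (argument values are illustrative):

trainer = training_util.BasicTrainer(checkpoint)
trainer(dataset_fn, max_step=100000, save_steps=1000, report_steps=100)

Tracing _step with input_signature=dataset.element_spec builds one concrete function for the dataset's element structure, rather than retracing on every Python-level call.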


class DistributionStrategyTrainer(Trainer):
"""Trainer based on distribution strategies."""

def __init__(self, checkpoint, devices=None):
"""Initializes the trainer.
Args:
checkpoint: A :class:`opennmt.utils.checkpoint.Checkpoint` instance.
devices: List of device strings to use for training.
"""
super(DistributionStrategyTrainer, self).__init__(checkpoint)
if not devices:
devices = misc.get_devices(count=1) # Train with 1 device by default.
self._strategy = tf.distribute.MirroredStrategy(devices=devices)
self._words_counters = {}
with self._strategy.scope():
# Create some variables under the strategy scope.
_ = self._optimizer.iterations
self._gradient_accumulator = optimizer_util.GradientAccumulator()
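
For example, a hypothetical two-GPU setup, mirroring the runner.py change above (and assuming __call__ exposes accum_steps, as its body suggests):

trainer = training_util.DistributionStrategyTrainer(
    checkpoint,
    devices=misc.get_devices(count=2))  # e.g. ["/gpu:0", "/gpu:1"]
trainer(dataset_fn, max_step=100000, accum_steps=4)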

def _get_words_counters(self):
return {name: value.numpy() for name, value in self._synchronize_words_counters().items()}

def _steps(self, dataset, accum_steps=1, report_steps=None):
"""Returns a generator over training steps."""
self._gradient_accumulator.reset()
self._words_counters.clear()
for i, loss in enumerate(self._accumulate_next_gradients(dataset, report_steps=report_steps)):
if tf.math.is_nan(loss):
raise RuntimeError("Model diverged with loss = NaN.")
if i == 0 or (i + 1) % accum_steps == 0:
self._apply_gradients()
yield self._optimizer.iterations.numpy(), loss
yield loss
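
Tracing the condition above for accum_steps=4 shows the update cadence:

#   i = 0        -> apply immediately (a first update after a single batch)
#   i = 1, 2     -> accumulate only
#   i = 3        -> (3 + 1) % 4 == 0 -> apply with 3 accumulated batches
#   i = 4 ... 7  -> accumulate, apply at i = 7 with 4 batches (steady state)
# After warm-up, each optimizer step therefore consumes accum_steps batches
# on each replica.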

def _accumulate_next_gradients(self, dataset, report_steps=None):
"""Accumulates the gradients from the next element in :obj:`dataset`."""
@@ -141,7 +245,6 @@ def _accumulate_next_gradients(self, dataset, report_steps=None):

@tf.function
def _accumulate_next():
tf.summary.experimental.set_step(self._optimizer.iterations)
if report_steps is None:
should_record_summaries = False
else:
@@ -168,19 +271,8 @@ def _accumulate_gradients(self, per_replica_source, per_replica_target):

def _accumulate_gradients_on_replica(self, source, target):
"""Accumulates the gradients (in replica)."""
outputs, _ = self._model(
source,
labels=target,
training=True,
step=self._optimizer.iterations)
loss = self._model.compute_loss(outputs, target, training=True)
if isinstance(loss, tuple):
training_loss = loss[0] / loss[1]
reported_loss = loss[0] / loss[2] if len(loss) > 2 else training_loss
else:
training_loss, reported_loss = loss, loss
training_loss, reported_loss = self._run_model(source, target)
variables = self._model.trainable_variables
training_loss = self._model.regularize_loss(training_loss, variables=variables)
gradients = self._optimizer.get_gradients(training_loss, variables)
self._gradient_accumulator(gradients)
tf.summary.scalar("gradients/global_norm", tf.linalg.global_norm(gradients))
@@ -234,27 +326,35 @@ def _apply_gradients_on_replica(self):
self._gradient_accumulator.reset()


def _report_training_status(step, loss, learning_rate, words_counters, last_report_time):
tf.summary.experimental.set_step(step)
new_report_time = time.time()
def _report_training_status(step,
loss,
learning_rate,
words_counters,
last_report_step,
last_report_time):
elapsed_time = time.time() - last_report_time

steps_per_sec = (step - last_report_step) / elapsed_time
tf.summary.scalar("steps_per_sec", steps_per_sec, description="Training steps per second")
steps_per_sec_fmt = "steps/s = %0.2f" % steps_per_sec

words_per_sec_fmt = []
for name, counter in words_counters.items():
avg = int(counter.numpy() / (new_report_time - last_report_time))
avg = int(counter / elapsed_time)
tf.summary.scalar(
"words_per_sec/%s" % name,
avg,
description="%s words per second" % name.capitalize())
fmt = "%s words/s = %d" % (name, avg)
words_per_sec_fmt.append(fmt)
words_per_sec_fmt = sorted(words_per_sec_fmt)
words_per_sec_fmt.append("%s words/s = %d" % (name, avg))

if isinstance(learning_rate, tf.optimizers.schedules.LearningRateSchedule):
learning_rate = learning_rate(step)

tf.get_logger().info(
"Step = %d ; %s ; Learning rate = %f ; Loss = %f",
step,
", ".join(words_per_sec_fmt),
", ".join([steps_per_sec_fmt] + list(sorted(words_per_sec_fmt))),
learning_rate,
loss)
tf.summary.scalar("loss", loss, description="Training loss")
tf.summary.scalar("optim/learning_rate", learning_rate, description="Learning rate")
return new_report_time
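
Worked through with illustrative numbers: with report_steps=100, suppose the last 100 steps took 20 seconds and the "target" words counter accumulated 512000 tokens:

#   steps_per_sec  = (1100 - 1000) / 20   # = 5.00
#   target words/s = int(512000 / 20)     # = 25600
# giving a log line roughly like:
#   Step = 1100 ; steps/s = 5.00, target words/s = 25600 ; Learning rate = 0.000350 ; Loss = 2.914213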
