Add sentence weighting (#767)
guillaumekln committed Jan 7, 2021
1 parent 291fe9b commit e4d2ebb
Showing 8 changed files with 178 additions and 58 deletions.
4 changes: 4 additions & 0 deletions docs/configuration.md
@@ -74,6 +74,10 @@ data:
# (optional) Pharaoh alignments of the training files.
train_alignments: data/toy-ende/alignments-train.txt

# (optional) File containing the weight of each example (one weight per line).
# The loss value of each example is multiplied by its corresponding weight.
example_weights: data/toy-ende/weights-train.txt

# (required for train_end_eval and eval run types).
eval_features_file: data/toy-ende/src-val.txt
eval_labels_file: data/toy-ende/tgt-val.txt
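For reference, the weights file is plain text with one floating-point value per line, aligned line for line with the training files. A hypothetical excerpt of data/toy-ende/weights-train.txt (values made up for illustration) could look like:

0.5
1.0
2.0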
71 changes: 70 additions & 1 deletion opennmt/inputters/inputter.py
@@ -697,16 +697,28 @@ def make_training_dataset(self,
return dataset


def _register_example_weight(features, labels, weight):
labels["weight"] = tf.strings.to_number(weight)
return features, labels


class ExampleInputter(ParallelInputter, ExampleInputterAdapter):
"""An inputter that returns training examples (parallel features and labels)."""

def __init__(self, features_inputter, labels_inputter, share_parameters=False):
def __init__(self,
features_inputter,
labels_inputter,
share_parameters=False,
accepted_annotations=None):
"""Initializes this inputter.
Args:
features_inputter: An inputter producing the features (source).
labels_inputter: An inputter producing the labels (target).
share_parameters: Share the inputters parameters.
accepted_annotations: An optional dictionary mapping annotation names in
the data configuration (e.g. "train_alignments") to a callable with
signature ``(features, labels, annotations) -> (features, labels)``.
"""
self.features_inputter = features_inputter
self.labels_inputter = labels_inputter
@@ -718,6 +730,63 @@ def __init__(self, features_inputter, labels_inputter, share_parameters=False):
self.features_inputter.asset_prefix = "source_"
self.labels_inputter.asset_prefix = "target_"

self.accepted_annotations = accepted_annotations or {}
self.accepted_annotations["example_weights"] = _register_example_weight
self.annotation_files = {}

def initialize(self, data_config):
super().initialize(data_config)

# Check if some accepted annotations are defined in the data configuration.
for annotation in self.accepted_annotations.keys():
path = data_config.get(annotation)
if path is not None:
self.annotation_files[annotation] = path

def make_dataset(self, data_file, training=None):
dataset = super().make_dataset(data_file, training=training)
if not training or not self.annotation_files:
return dataset

# Some annotations are configured and should be zipped to the training dataset.
all_annotation_datasets = tf.nest.map_structure(tf.data.TextLineDataset, self.annotation_files)

# Common case of a non-weighted dataset.
if not isinstance(dataset, list):
return tf.data.Dataset.zip({"examples": dataset, **all_annotation_datasets})

# Otherwise, there should be as many annotations datasets as input datasets.
datasets = dataset
for name, annotation_datasets in all_annotation_datasets.items():
num_annotation_datasets = (
len(annotation_datasets) if isinstance(annotation_datasets, list) else 1)
if num_annotation_datasets != len(datasets):
raise ValueError("%d '%s' files were provided, but %d were expected to match the "
"number of data files" % (num_annotation_datasets, name, len(datasets)))

# Convert dict of lists to list of dicts.
all_annotation_datasets = [
dict(zip(all_annotation_datasets, t)) for t in zip(*all_annotation_datasets.values())]

return [
tf.data.Dataset.zip({"examples": dataset, **annotation_datasets})
for dataset, annotation_datasets in zip(datasets, all_annotation_datasets)]

def make_features(self, element=None, features=None, training=None):
if training and self.annotation_files:
annotations = element.copy()
example = annotations.pop("examples")
else:
annotations = {}
example = element

features, labels = super().make_features(element=example, features=features, training=training)

# Load each annotation into the features and labels dict.
for name, annotation in annotations.items():
features, labels = self.accepted_annotations[name](features, labels, annotation)
return features, labels

def make_inference_dataset(self,
features_file,
batch_size,
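To make the annotation flow concrete, here is a minimal standalone sketch (not part of the commit; the toy sentences and weight values are assumptions) of what make_dataset and make_features do for the example_weights annotation: the weight file is zipped with the training examples, then each weight string is parsed into labels["weight"] by _register_example_weight.

import tensorflow as tf

# Toy stand-ins for the training examples and the example_weights file.
examples = tf.data.Dataset.from_tensor_slices(["hello world", "how are you"])
weights = tf.data.Dataset.from_tensor_slices(["0.5", "2.0"])

# make_dataset: zip the annotation datasets with the training examples.
dataset = tf.data.Dataset.zip({"examples": examples, "example_weights": weights})

# make_features: pop the example, then let each registered callable load its annotation.
for element in dataset:
    annotations = dict(element)
    example = annotations.pop("examples")
    features, labels = {"text": example}, {"text": example}
    labels["weight"] = tf.strings.to_number(annotations["example_weights"])
    print(labels["weight"].numpy())  # 0.5, then 2.0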
1 change: 1 addition & 0 deletions opennmt/models/sequence_classifier.py
@@ -62,6 +62,7 @@ def compute_loss(self, outputs, labels, training=True):
return cross_entropy_loss(
outputs,
labels["classes_id"],
weight=labels.get("weight"),
label_smoothing=self.params.get("label_smoothing", 0.0),
training=training)

3 changes: 2 additions & 1 deletion opennmt/models/sequence_tagger.py
@@ -77,7 +77,8 @@ def compute_loss(self, outputs, labels, training=True):
return cross_entropy_sequence_loss(
outputs,
labels["tags_id"],
labels["length"],
sequence_length=labels["length"],
sequence_weight=labels.get("weight"),
label_smoothing=self.params.get("label_smoothing", 0.0),
average_in_time=self.params.get("average_loss_in_time", False),
training=training)
50 changes: 12 additions & 38 deletions opennmt/models/sequence_to_sequence.py
@@ -321,7 +321,8 @@ def compute_loss(self, outputs, labels, training=True):
loss, loss_normalizer, loss_token_normalizer = losses.cross_entropy_sequence_loss(
logits,
labels["ids_out"],
labels["length"],
sequence_length=labels["length"],
sequence_weight=labels.get("weight"),
label_smoothing=params.get("label_smoothing", 0.0),
average_in_time=params.get("average_loss_in_time", False),
training=training)
@@ -419,44 +420,17 @@ def __init__(self,
labels_inputter,
share_parameters=False):
super().__init__(
features_inputter, labels_inputter, share_parameters=share_parameters)
features_inputter,
labels_inputter,
share_parameters=share_parameters,
accepted_annotations={"train_alignments": self._register_alignment})
labels_inputter.set_decoder_mode(mark_start=True, mark_end=True)
self.alignment_file = None

def initialize(self, data_config):
super().initialize(data_config)
self.alignment_file = data_config.get("train_alignments")

def make_dataset(self, data_file, training=None):
dataset = super().make_dataset(
data_file, training=training)
if self.alignment_file is None or not training:
return dataset
if not isinstance(dataset, list):
return tf.data.Dataset.zip((dataset, tf.data.TextLineDataset(self.alignment_file)))
datasets = dataset
alignment_files = self.alignment_file
if not isinstance(alignment_files, list):
alignment_files = [alignment_files]
if len(alignment_files) != len(datasets):
raise ValueError("%d alignment files were provided, but %d were expected to match the "
"number of data files" % (len(alignment_files), len(datasets)))
return [
tf.data.Dataset.zip((dataset, tf.data.TextLineDataset(alignment_file)))
for dataset, alignment_file in zip(datasets, alignment_files)]

def make_features(self, element=None, features=None, training=None):
if training and self.alignment_file is not None:
element, alignment = element
else:
alignment = None
features, labels = super().make_features(
element=element, features=features, training=training)
if alignment is not None:
labels["alignment"] = text.alignment_matrix_from_pharaoh(
alignment,
self.features_inputter.get_length(features, ignore_special_tokens=True),
self.labels_inputter.get_length(labels, ignore_special_tokens=True))

def _register_alignment(self, features, labels, alignment):
labels["alignment"] = text.alignment_matrix_from_pharaoh(
alignment,
self.features_inputter.get_length(features, ignore_special_tokens=True),
self.labels_inputter.get_length(labels, ignore_special_tokens=True))
return features, labels


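The refactoring above also makes the hook reusable: any ExampleInputter subclass can register additional annotations. As a hypothetical sketch (not part of the commit; the "train_domains" annotation name and the MyExampleInputter class are invented for illustration):

import tensorflow as tf
from opennmt.inputters import ExampleInputter

class MyExampleInputter(ExampleInputter):
    def __init__(self, features_inputter, labels_inputter):
        super().__init__(
            features_inputter,
            labels_inputter,
            accepted_annotations={"train_domains": self._register_domain})

    def _register_domain(self, features, labels, domain):
        # "train_domains" would point to a file with one integer domain id per line.
        features["domain"] = tf.strings.to_number(domain, out_type=tf.int64)
        return features, labels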
41 changes: 41 additions & 0 deletions opennmt/tests/losses_test.py
@@ -7,6 +7,47 @@

class LossesTest(tf.test.TestCase):

def testCrossEntropySequenceLoss(self):
logits = tf.constant([
[[0.1, 0.2, 0.9], [-1.2, 2.1, 0], [0.6, 0.3, 0.4]],
[[-2.2, -0.2, -1.2], [2.3, 0.2, -0.1], [0.0, 0.1, 0.7]]])
labels = tf.constant([[2, 1, 0], [1, 0, 2]], dtype=tf.int32)

loss, training_norm, stats_norm = losses.cross_entropy_sequence_loss(
logits, labels, training=True)
self.assertNear(loss, 3.06985, 1e-5)
self.assertEqual(training_norm, 2)
self.assertEqual(stats_norm, 6)

_, training_norm, stats_norm = losses.cross_entropy_sequence_loss(
logits, labels, average_in_time=True, training=True)
self.assertEqual(training_norm, 6)
self.assertEqual(stats_norm, 6)

def testMaskedCrossEntropySequenceLoss(self):
logits = tf.constant([
[[0.1, 0.2, 0.9], [-1.2, 2.1, 0], [0.6, 0.3, 0.4]],
[[-2.2, -0.2, -1.2], [2.3, 0.2, -0.1], [0.0, 0.1, 0.7]]])
labels = tf.constant([[2, 1, 0], [1, 0, 2]], dtype=tf.int32)
lengths = tf.constant([2, 1], dtype=tf.int32)

loss, _, stats_norm = losses.cross_entropy_sequence_loss(
logits, labels, sequence_length=lengths, training=True)
self.assertNear(loss, 1.22118, 1e-5)
self.assertEqual(stats_norm, 3)

def testWeightedAndMaskedCrossEntropySequenceLoss(self):
logits = tf.constant([
[[0.1, 0.2, 0.9], [-1.2, 2.1, 0], [0.6, 0.3, 0.4]],
[[-2.2, -0.2, -1.2], [2.3, 0.2, -0.1], [0.0, 0.1, 0.7]]])
labels = tf.constant([[2, 1, 0], [1, 0, 2]], dtype=tf.int32)
lengths = tf.constant([3, 2], dtype=tf.int32)
weights = tf.constant([0.6, 1.2])

loss, _, _ = losses.cross_entropy_sequence_loss(
logits, labels, sequence_length=lengths, sequence_weight=weights, training=True)
self.assertNear(loss, 1.77306, 1e-5)

@parameterized.expand([
["l1", 1e-4],
["L1", 1e-4],
20 changes: 18 additions & 2 deletions opennmt/tests/model_test.py
@@ -28,7 +28,7 @@ def _seq2seq_model(training=None):

class ModelTest(tf.test.TestCase):

def _makeToyEnDeData(self, with_alignments=False):
def _makeToyEnDeData(self, with_alignments=False, with_weights=False):
data_config = {}
features_file = test_util.make_data_file(
os.path.join(self.get_temp_dir(), "src.txt"),
@@ -55,10 +55,13 @@ def _makeToyEnDeData(self, with_alignments=False):
if with_alignments:
# Dummy and incomplete alignments.
data_config["train_alignments"] = test_util.make_data_file(
os.path.join(self.get_temp_dir(), "aligne.txt"),
os.path.join(self.get_temp_dir(), "alignments.txt"),
["0-0 1-0 2-2 3-4 4-4 5-6",
"0-1 1-1 1-3 2-3 4-4",
"0-0 1-0 2-2 3-4 4-4 5-6"])
if with_weights:
data_config["example_weights"] = test_util.make_data_file(
os.path.join(self.get_temp_dir(), "weights.txt"), ["0.6", "1", "1e-2"])
return features_file, labels_file, data_config

def _makeToyLMData(self):
@@ -242,6 +245,19 @@ def testSequenceToSequenceWithGuidedAlignmentAndWeightedDataset(self):
[features_file, features_file], [labels_file, labels_file], 16)
self.assertIsInstance(dataset, tf.data.Dataset)

def testSequenceToSequenceWithWeightedExamples(self):
model, params = _seq2seq_model(training=True)
features_file, labels_file, data_config = self._makeToyEnDeData(with_weights=True)
model.initialize(data_config, params=params)
dataset = model.examples_inputter.make_training_dataset(features_file, labels_file, 16)
features, labels = next(iter(dataset))
self.assertIn("weight", labels)
outputs, _ = model(features, labels=labels, training=True)
weighted_loss, _, _ = model.compute_loss(outputs, labels, training=True)
labels.pop("weight")
default_loss, _, _ = model.compute_loss(outputs, labels, training=True)
self.assertNotEqual(weighted_loss, default_loss)

def testSequenceToSequenceWithReplaceUnknownTarget(self):
model, params = _seq2seq_model()
params["replace_unknown_target"] = True
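End-to-end usage is unchanged apart from the extra example_weights entry in the data configuration. Assuming a standard OpenNMT-tf setup (the config file name data.yml is an assumption), training would be launched as usual:

onmt-main --model_type Transformer --config data.yml --auto_config train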
46 changes: 30 additions & 16 deletions opennmt/utils/losses.py
@@ -24,55 +24,69 @@ def _softmax_cross_entropy(logits, labels, label_smoothing, training):

def cross_entropy_sequence_loss(logits,
labels,
sequence_length,
sequence_length=None,
label_smoothing=0.0,
average_in_time=False,
training=None):
training=None,
sequence_weight=None):
"""Computes the cross entropy loss of sequences.
Args:
logits: The unscaled probabilities.
labels: The true labels.
sequence_length: The length of each sequence.
logits: The unscaled probabilities with shape :math:`[B, T, V]`.
labels: The true labels with shape :math:`[B, T]`.
sequence_length: The length of each sequence with shape :math:`[B]`.
label_smoothing: The label smoothing value.
average_in_time: If ``True``, also average the loss in the time dimension.
training: Compute training loss.
sequence_weight: The weight of each sequence with shape :math:`[B]`.
Returns:
A tuple (cumulated loss, loss normalizer, token-level normalizer).
"""
batch_size = tf.shape(logits)[0]
max_time = tf.shape(logits)[1]

cross_entropy = _softmax_cross_entropy(logits, labels, label_smoothing, training)
weights = tf.sequence_mask(
sequence_length, maxlen=max_time, dtype=cross_entropy.dtype)
loss = tf.reduce_sum(cross_entropy * weights)
loss_token_normalizer = tf.reduce_sum(weights)
dtype = cross_entropy.dtype

if sequence_weight is not None:
cross_entropy *= tf.expand_dims(tf.cast(sequence_weight, dtype), 1)

if sequence_length is not None:
max_time = tf.shape(logits)[1]
mask = tf.sequence_mask(sequence_length, maxlen=max_time, dtype=dtype)
cross_entropy *= mask
loss_token_normalizer = tf.reduce_sum(mask)
else:
loss_token_normalizer = tf.cast(tf.size(labels), dtype)

loss = tf.reduce_sum(cross_entropy)

if average_in_time or not training:
loss_normalizer = loss_token_normalizer
else:
loss_normalizer = tf.cast(batch_size, loss.dtype)
batch_size = tf.shape(logits)[0]
loss_normalizer = tf.cast(batch_size, dtype)

return loss, loss_normalizer, loss_token_normalizer

def cross_entropy_loss(logits,
labels,
label_smoothing=0.0,
training=None):
training=None,
weight=None):
"""Computes the cross entropy loss.
Args:
logits: The unscaled probabilities.
labels: The true labels.
logits: The unscaled probabilities with shape :math:`[B, V]`.
labels: The true labels with shape :math:`[B]`.
label_smoothing: The label smoothing value.
training: Compute training loss.
weight: The weight of each example with shape :math:`[B]`.
Returns:
The cumulated loss and the loss normalizer.
"""
cross_entropy = _softmax_cross_entropy(logits, labels, label_smoothing, training)
if weight is not None:
cross_entropy *= tf.cast(weight, cross_entropy.dtype)
loss = tf.reduce_sum(cross_entropy)
loss_normalizer = tf.cast(tf.shape(cross_entropy)[0], loss.dtype)
return loss, loss_normalizer
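As a quick numeric illustration (all values assumed, not taken from the commit or its tests), the new sequence_weight argument scales each row of the per-token cross entropy before masking and summation:

import tensorflow as tf

cross_entropy = tf.constant([[1.0, 2.0, 3.0],
                             [4.0, 5.0, 6.0]])  # per-token loss, shape [B=2, T=3]
sequence_weight = tf.constant([0.5, 2.0])       # one weight per sequence, shape [B]
sequence_length = tf.constant([3, 2])           # shape [B]

weighted = cross_entropy * tf.expand_dims(sequence_weight, 1)
mask = tf.sequence_mask(sequence_length, maxlen=3, dtype=weighted.dtype)
loss = tf.reduce_sum(weighted * mask)           # 0.5*(1+2+3) + 2.0*(4+5) = 21.0
loss_token_normalizer = tf.reduce_sum(mask)     # 5 unmasked tokens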
