Fix the inference of transformer-based models trained with masked language modeling (#909)

* add the inference fix to ReplaceMaskedEmbeddings

* first solution for inference support

* updates based on PR comments

* Apply suggestions from code review

Co-authored-by: Gabriel Moreira <gmoreira@nvidia.com>

* Use merlin-dataloader package (#845)

* Use merlin-dataloader package

* remove torch.dataset in favor of merlin.loader.torch

* update dressipi notebook

* minor clean up

* Completely removes models DataLoader

* Installs merlin-dataloader in github actions

* Adds back the stop method

* dataloader can produce sparse tensors using value counts

* remove data files

* fix torch tests

* add missing target to dlrm test

* use loader.peek()

* add some comments to help understand horovod tests

* make sparse tensors optional

* cleanup

* fix spelling

* fix merge

* replace while loop with for loop in horovod test

* use loader context manager

* Update according to dataloader changes #80

* restore tox.ini

* restore gh workflow

* revert generator changes

* Restore documentation build (#916)

- Change Python 3.9.7 to 3.8.
- Update the versions of the GH actions.
- Update pre-commit config file to get
  flake8 from GitHub instead of GitLab.

* Support `tuple` return type from model `pre` and update test to use this (#890)

* Support `tuple` return type from `pre` arg to `evaluate`, `predict`

* Update CLM transformer test to use `pre` instead of Loader `transform`

* Update youtube dnn tests to use transform as model fit pre

* Add `pre` to ModelBlock fit/evaluate

* Revert "Add `pre` to ModelBlock fit/evaluate"

This reverts commit 1eef7b8.

* Raise exception if ragged/sparse tensors are passed at training time.

* Update model_test helper to avoid passing ragged tensors to `fit`

* Handle x and y in model_test

* Change process_lists param to False by default

* Convert to tuple in test loader

* Move order of ragged tensor assertion to before train_pre call

* expand dims in test_classification

* pass transform as pre in test in batch negatives

* Update continuous and retrieval tests

* Remove test of sequence predict functions with loader

* Update error message about ragged tensors for clarity

* Add explanation about why the input types are restricted

* Rename dataset to dataloader in model_test

Co-authored-by: rnyak <ronayak@hotmail.com>

* add assertion check to TransformerInferenceHiddenState

Co-authored-by: Gabriel Moreira <gmoreira@nvidia.com>
Co-authored-by: edknv <109497216+edknv@users.noreply.github.com>
Co-authored-by: mikemckiernan <mmckiernan@nvidia.com>
Co-authored-by: Oliver Holworthy <oholworthy@nvidia.com>
Co-authored-by: rnyak <ronayak@hotmail.com>
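In practical terms, this commit lets a transformer trained with masked language modeling serve next-item predictions directly. Below is a minimal sketch of the resulting usage, adapted from the updated test in this commit; `seq_schema` (a sequential schema) and `target` (the item-id column name) are assumed to be prepared from a Merlin dataset as in that test:

```python
import merlin.models.tf as mm
from merlin.schema import Tags

# `seq_schema` and `target` are assumed to be prepared as in
# tests/unit/tf/transformers/test_block.py.
model = mm.Model(
    mm.InputBlockV2(
        seq_schema,
        categorical=mm.Embeddings(
            seq_schema.select_by_tag(Tags.CATEGORICAL), sequence_combiner=None
        ),
    ),
    mm.BertBlock(
        d_model=48,
        n_head=8,
        n_layer=2,
        # Train with random masking; at inference, extend each sequence by one
        # position so the model scores the "next item" slot.
        pre=mm.SequentialBlock(
            [mm.SequenceMaskLastInference(), mm.ReplaceMaskedEmbeddings()]
        ),
        # Select only the hidden state of that next-item position.
        post="inference_hidden_state",
    ),
    mm.CategoricalOutput(
        seq_schema.select_by_name(target),
        default_loss="categorical_crossentropy",
    ),
)

# predict() now returns one score vector per example, (batch_size, n_items),
# rather than per-position scores (batch_size, seq_len, n_items).
```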
6 people committed Dec 13, 2022
1 parent a8ab140 commit df84c81
Showing 6 changed files with 138 additions and 20 deletions.
1 change: 1 addition & 0 deletions merlin/models/tf/__init__.py
@@ -158,6 +158,7 @@
from merlin.models.tf.transforms.sequence import (
    ReplaceMaskedEmbeddings,
    SequenceMaskLast,
    SequenceMaskLastInference,
    SequenceMaskRandom,
    SequencePredictLast,
    SequencePredictNext,
2 changes: 1 addition & 1 deletion merlin/models/tf/transformers/block.py
@@ -91,7 +91,7 @@ def __init__(
            self.transformer = get_tf_main_layer(transformer)
        else:
            self.transformer = transformer

        self.transformer.supports_masking = True
        if "transformer" in inspect.signature(transformer_pre.__init__).parameters:
            transformer_pre = transformer_pre(transformer=self.transformer)
        self.transformer_pre = transformer_pre
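For context on the `supports_masking = True` line added above: a Keras layer that receives a masked input must either consume the mask or declare that it can propagate it, otherwise Keras raises a "does not support masking" error. A minimal, generic Keras sketch (not Merlin code) of the mechanism:

```python
import tensorflow as tf

class PassThrough(tf.keras.layers.Layer):
    """Forwards its input and, because supports_masking is True,
    also forwards any Keras mask attached to that input."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.supports_masking = True

    def call(self, inputs):
        return inputs

# Embedding(mask_zero=True) attaches a mask; PassThrough propagates it.
inputs = tf.keras.Input(shape=(None,), dtype="int32")
x = tf.keras.layers.Embedding(100, 8, mask_zero=True)(inputs)
outputs = PassThrough()(x)  # no "does not support masking" error
```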
67 changes: 67 additions & 0 deletions merlin/models/tf/transformers/transforms.py
@@ -37,10 +37,66 @@ class LastHiddenState(Layer):
    The output class returned by the HuggingFace transformer layer
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.supports_masking = True

    def call(self, inputs: TFBaseModelOutputWithPoolingAndCrossAttentions):
        return inputs.last_hidden_state


@Block.registry.register("inference_hidden_state")
@tf.keras.utils.register_keras_serializable(package="merlin.models")
class TransformerInferenceHiddenState(Layer):
"""A post-processing layer to select the hidden state
of the next-item position, during inference.
"""

def __init__(self, **kwargs):
super().__init__(**kwargs)
self.supports_masking = True

def call(
self,
inputs: tf.Tensor,
training: bool = False,
testing: bool = False,
):
"""Select the hidden state of the target (last) position, during inference.
During training or testing, the inputs are returned
without any processing.
Parameters:
----------
inputs: tf.Tensor
The 3-D output tensor returned by the transformer block
training : bool, optional
Flag that indicates whether in training mode, by default True
testing : bool, optional
Flag that indicates whether in evaluation mode, by default True
Returns
-------
tf.Tensor
If inference, returns a 2-D tensor with the hidden states of
the target position
"""
batch_size = tf.shape(inputs)[0]
if not training and not testing:
if getattr(inputs, "_keras_mask", None) is not None:
inputs = tf.reshape(
tf.boolean_mask(inputs, inputs._keras_mask), (-1, inputs.shape[-1])
)
tf.debugging.assert_equal(
tf.shape(inputs)[0],
batch_size,
f"The resulting tensor has {tf.shape(inputs)[0]} rows, which does not match"
f" the inputs batch-size {batch_size}. During inference only one position "
"candidate (the last one) should be masked per example",
)
return inputs


@Block.registry.register("pooler_output")
@tf.keras.utils.register_keras_serializable(package="merlin.models")
class PoolerOutput(Layer):
@@ -113,10 +169,21 @@ def call(self, inputs: TFBaseModelOutputWithPoolingAndCrossAttentions):
class PrepareTransformerInputs(tf.keras.layers.Layer):
    """Prepare the dictionary of inputs expected by the transformer layer"""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.supports_masking = True

    def call(self, inputs: tf.Tensor) -> Dict[str, tf.Tensor]:
        mask = None
        if getattr(inputs, "_keras_mask", None) is not None and isinstance(
            inputs._keras_mask, tf.RaggedTensor
        ):
            mask = inputs._keras_mask.to_tensor()
        if isinstance(inputs, tf.RaggedTensor):
            # convert to a dense tensor, as HF transformers do not support ragged tensors
            inputs = inputs.to_tensor()
        if mask is not None:
            inputs._keras_mask = mask
        return {"inputs_embeds": inputs}


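To make the selection step in `TransformerInferenceHiddenState` concrete, here is a toy, TensorFlow-only sketch of the same `tf.boolean_mask` + `tf.reshape` combination; the shapes and mask values are illustrative assumptions, not taken from the diff:

```python
import tensorflow as tf

# Stand-in for the transformer output: batch of 2 sequences,
# 4 positions each, hidden size 3 (values are arbitrary).
hidden = tf.random.uniform((2, 4, 3))

# At inference, exactly one position per example is masked:
# the appended next-item slot (here at index 3 and index 2).
mask = tf.constant([[False, False, False, True],
                    [False, False, True, False]])

# The same selection the layer performs: keep only the masked positions
# and flatten to a 2-D (batch_size, hidden_size) tensor.
selected = tf.reshape(tf.boolean_mask(hidden, mask), (-1, hidden.shape[-1]))
print(selected.shape)  # (2, 3): one hidden state per example
```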
67 changes: 56 additions & 11 deletions merlin/models/tf/transforms/sequence.py
@@ -621,22 +621,67 @@ def from_config(cls, config):
        return cls(schema, target, **config)


@tf.keras.utils.register_keras_serializable(package="merlin.models")
class SequenceMaskLastInference(Block):
    def call(self, inputs, training=False, testing=False):
        self.inference_mode = not training and not testing
        if self.inference_mode:
            # Extend each sequence by one position, copying the last embedding
            repeat = inputs[:, -1:, :]
            inputs = tf.concat([inputs, repeat], axis=1)
        return inputs

    def compute_mask(self, inputs, mask=None):
        """Selects (masks) the next position after the
        last valid (non-padded) position of the sequential targets
        to be predicted.
        This method is called by Keras after call()
        and returns the mask that is going to be assigned
        to the input tensors, which is accessible
        through tensor._keras_mask
        """
        targets_mask = None
        if self.inference_mode:
            if isinstance(inputs, tf.RaggedTensor):
                row_lengths = inputs.row_lengths(1) + 1
                max_seq_length = tf.cast(tf.reduce_max(row_lengths), tf.int32)

                padding_mask = tf.sequence_mask(row_lengths)
                targets_mask = tf.ragged.boolean_mask(
                    tf.cast(tf.one_hot(row_lengths - 1, max_seq_length), tf.bool), padding_mask
                )

        return targets_mask


@tf.keras.utils.register_keras_serializable(package="merlin.models")
class ReplaceMaskedEmbeddings(Block):
"""Takes a 3D input tensor (batch size x seq. length x embedding dim) and replaces
    by a dummy trainable single embedding at the positions to be masked.
    This block looks for the Keras mask (`._keras_mask`) in the following order:
      1. Checks if the input tensor has a mask
      2. Checks if there is a single target and if it has a mask
      3. If there are multiple targets (dict), returns the mask of the target
         that matches the first 2 dims of the input
    This is useful when the PredictMasked() transformation is used in
    the Loader, which randomly selects some targets to be predicted and uses
    Keras masking to cascade the `_keras_mask`. By replacing input embeddings
    at masked positions, we avoid target leakage when training models with
    Masked Language Modeling (BERT-like).

    **Note:** To support inference, the input sequence and its corresponding mask should be
    extended by one position at the end, to account for the next-item (`target`) position.
    To do this, set `SequenceMaskLastInference` as a pre-layer of
    `ReplaceMaskedEmbeddings()` using a sequential block:
    `mm.SequentialBlock([mm.SequenceMaskLastInference(), mm.ReplaceMaskedEmbeddings()])`
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.supports_masking = True

    def build(self, input_shape):
        self.hidden_size = input_shape[-1]
        if self.hidden_size is None:
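The `compute_mask` logic in `SequenceMaskLastInference` is compact; this self-contained sketch traces the same operations on a tiny ragged batch (the example values are assumptions for illustration):

```python
import tensorflow as tf

# A ragged batch of two embedding sequences (lengths 3 and 2, hidden size 4).
inputs = tf.ragged.constant(
    [[[0.1] * 4] * 3, [[0.2] * 4] * 2], ragged_rank=1
)

# call(): at inference, extend each sequence by repeating its last embedding.
extended = tf.concat([inputs, inputs[:, -1:, :]], axis=1)

# compute_mask(): flag only the appended (next-item) position of each row.
row_lengths = inputs.row_lengths(1) + 1  # [4, 3]
max_seq_length = tf.cast(tf.reduce_max(row_lengths), tf.int32)
padding_mask = tf.sequence_mask(row_lengths)
targets_mask = tf.ragged.boolean_mask(
    tf.cast(tf.one_hot(row_lengths - 1, max_seq_length), tf.bool), padding_mask
)
# targets_mask: [[False, False, False, True], [False, False, True]]
```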
13 changes: 9 additions & 4 deletions tests/unit/tf/transformers/test_block.py
@@ -284,7 +284,13 @@ def test_transformer_with_masked_language_modeling(sequence_testing_data: Dataset):
                seq_schema.select_by_tag(Tags.CATEGORICAL), sequence_combiner=None
            ),
        ),
        BertBlock(
            d_model=48,
            n_head=8,
            n_layer=2,
            pre=mm.SequentialBlock([mm.SequenceMaskLastInference(), mm.ReplaceMaskedEmbeddings()]),
            post="inference_hidden_state",
        ),
        mm.CategoricalOutput(
            seq_schema.select_by_name(target),
            default_loss="categorical_crossentropy",
@@ -308,10 +314,9 @@
    metrics = model.evaluate(loader, batch_size=8, steps=1, return_dict=True, pre=seq_mask_last)
    assert len(metrics) > 0

    # Get predictions for the next-item position
    predictions = model.predict(loader, batch_size=8, steps=1)
    assert predictions.shape == (8, 51997)


@pytest.mark.parametrize("run_eagerly", [True, False])
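One note on the `evaluate` call in the hunk above: it passes `pre=seq_mask_last`, which is defined earlier in the test, outside the lines shown here. A plausible, hypothetical reconstruction of that definition, based on the `SequenceMaskLast` signature visible in this diff (`cls(schema, target, **config)`):

```python
# Hypothetical reconstruction (not shown in this diff): mask the last item
# of each sequence so that evaluation predicts the held-out last item.
seq_mask_last = mm.SequenceMaskLast(
    schema=seq_schema.select_by_tag(Tags.SEQUENCE), target=target
)
```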
8 changes: 4 additions & 4 deletions tests/unit/tf/transforms/test_sequence.py
@@ -252,7 +252,7 @@ def test_replace_masked_input_embeddings_no_target():
    targets = None

    masked_embeddings = mm.ReplaceMaskedEmbeddings()
    output = masked_embeddings(item_id_emb_seq, targets=targets, training=True)
    # Checks that no input embedding was replaced, as there was no masking defined
    tf.Assert(tf.logical_not(tf.reduce_all(output == item_id_emb_seq)), [])

@@ -262,7 +262,7 @@ def test_not_replace_unmasked_sequence_embeddings():
    targets = tf.random.uniform((8, 10), dtype=tf.float32)

    masked_embeddings = mm.ReplaceMaskedEmbeddings()
    output = masked_embeddings(item_id_emb_seq, targets=targets, training=True)
    # Checks that no input embedding was replaced, as there was no masking defined
    tf.Assert(tf.reduce_all(output == item_id_emb_seq), [])

@@ -275,7 +275,7 @@ def test_replace_masked_input_2d_embeddings_incompatible_2d_mask():
    masked_embeddings = mm.ReplaceMaskedEmbeddings()

    with pytest.raises(Exception) as exc_info:
        _ = masked_embeddings(item_id_emb_seq, training=True)
    assert "The inputs and mask need to be compatible" in str(exc_info.value)


@@ -287,7 +287,7 @@ def test_replace_masked_input_2d_embeddings_incompatible_ragged_2d_mask():
    masked_embeddings = mm.ReplaceMaskedEmbeddings()

    with pytest.raises(Exception) as exc_info:
        _ = masked_embeddings(item_id_emb_seq, training=True)
    assert "The inputs and mask need to be compatible" in str(exc_info.value)

