Bidirectional LM Embedder (#2138)

- A token embedder for transformer-based bidirectional LMs.
- Includes a demo model config: constituency_parser_transformer_elmo.jsonnet
- New tests for training a transformer-based bidirectional language model.
brendan-ai2 committed Dec 21, 2018
1 parent 0889a0d commit ce060badd12d3047e3af81cf97d0b62805e397e5
@@ -97,7 +97,11 @@ def tokens_to_indices(self,
tokens: List[Token],
vocabulary: Vocabulary,
index_name: str) -> Dict[str, List[List[int]]]:
# TODO(brendanr): Retain the token to index mappings in the vocabulary and remove this
# pylint pragma. See:
# https://github.com/allenai/allennlp/blob/master/allennlp/data/token_indexers/wordpiece_indexer.py#L113
# pylint: disable=unused-argument

texts = [token.text for token in tokens]

if any(text is None for text in texts):
@@ -37,6 +37,7 @@ def __init__(self,
lowercase_characters: bool = False,
start_tokens: List[str] = None,
end_tokens: List[str] = None) -> None:
# TODO(brendanr): Add length truncation.
self._byte_encoding = byte_encoding
self._lowercase_characters = lowercase_characters
self._start_tokens = start_tokens or []
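For a concrete picture of what byte_encoding-style character ids might look like, here is a minimal standalone sketch; it is not the AllenNLP CharacterTokenizer, and the +1 padding offset and the 259/260 start/end ids are assumptions chosen for the example.

def byte_char_ids(text, lowercase=False, start_tokens=(259,), end_tokens=(260,)):
    # Hypothetical helper: encode the token text as UTF-8 byte values,
    # offset by 1 so that id 0 stays free for padding (an assumption here).
    if lowercase:
        text = text.lower()
    ids = [b + 1 for b in text.encode("utf-8")]
    return list(start_tokens) + ids + list(end_tokens)

print(byte_char_ids("Hi!"))  # [259, 73, 106, 34, 260]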
@@ -9,7 +9,7 @@
from allennlp.modules.text_field_embedders import TextFieldEmbedder
from allennlp.modules.sampled_softmax_loss import SampledSoftmaxLoss
from allennlp.modules.seq2seq_encoders import Seq2SeqEncoder
from allennlp.nn.util import get_text_field_mask, remove_sentence_boundaries
from allennlp.nn.util import get_text_field_mask
from allennlp.nn import InitializerApplicator


@@ -67,15 +67,12 @@ class BidirectionalLanguageModel(Model):
Used to "contextualize" the embeddings. As described above,
this encoder must not cheat by peeking ahead.
dropout: ``float``, optional (default: None)
If specified, dropout is applied to the contextualized embeddings.
If specified, dropout is applied to the contextualized embeddings before computation of
the softmax. The contextualized embeddings themselves are returned without dropout.
loss_scale: ``Union[float, str]``, optional (default: 1.0)
This scaling factor is applied to the average language model loss.
You can also specify ``"n_samples"`` in which case we compute total
loss across all predictions.
remove_bos_eos: ``bool``, optional (default: True)
Typically the provided token indexes will be augmented with
begin-sentence and end-sentence tokens. If this flag is True
the corresponding embeddings will be removed from the return values.
num_samples: ``int``, optional (default: None)
If provided, the model will use ``SampledSoftmaxLoss``
with the specified number of samples. Otherwise, it will use
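To make the dropout note above concrete, a tiny standalone PyTorch sketch (shapes invented): dropout touches only the copy fed to the loss, and the embeddings handed back as lm_embeddings stay untouched.

import torch

dropout = torch.nn.Dropout(p=0.5)
contextual = torch.randn(2, 7, 16)      # (batch_size, timesteps, embed_dim), shapes invented
for_loss = dropout(contextual)          # only the softmax-loss path sees dropout
lm_embeddings = contextual              # what gets returned: no dropout applied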
@@ -89,7 +86,6 @@ def __init__(self,
contextualizer: Seq2SeqEncoder,
dropout: float = None,
loss_scale: Union[float, str] = 1.0,
remove_bos_eos: bool = True,
num_samples: int = None,
sparse_embeddings: bool = False,
initializer: InitializerApplicator = None) -> None:
@@ -123,7 +119,6 @@ def __init__(self,
self._dropout = lambda x: x

self._loss_scale = loss_scale
self._remove_bos_eos = remove_bos_eos
if initializer is not None:
initializer(self)
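A small standalone illustration of how loss_scale is applied (values invented; this mirrors the scaling logic further down in forward()):

import torch

loss_scale = "n_samples"                 # or a float like 1.0
average_loss = torch.tensor(2.6)         # per-target average NLL (invented)
forward_loss = torch.tensor(14.0)        # summed over targets (invented)
num_targets = torch.tensor(5)

scale_factor = num_targets.float() if loss_scale == "n_samples" else loss_scale
scaled_loss = average_loss * scale_factor             # what goes in return_dict['loss']
scaled_forward = forward_loss * scale_factor / num_targets.float()
print(scaled_loss, scaled_forward)                    # tensor(13.) tensor(14.)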

@@ -182,6 +177,24 @@ def _compute_loss(self,

return losses[0], losses[1]

def delete_softmax(self) -> None:
"""
Remove the softmax weights. Useful for saving memory when calculating the loss
is not necessary, e.g. in an embedder.
"""
self._softmax_loss = None

def num_layers(self) -> int:
"""
Returns the depth of this LM. That is, how many layers the contextualizer has plus one for
the non-contextual layer.
"""
if hasattr(self._contextualizer, 'num_layers'):
return self._contextualizer.num_layers + 1
else:
raise NotImplementedError(f"Contextualizer of type {type(self._contextualizer)} " +
"does not report how many layers it has.")

def forward(self, # type: ignore
source: Dict[str, torch.LongTensor]) -> Dict[str, torch.Tensor]:
"""
@@ -191,10 +204,6 @@ def forward(self, # type: ignore
entry that's the output of a ``SingleIdTokenIndexer``, which is used
to compute the language model targets.
If the model was instantiated with ``remove_bos_eos=True``,
then it is expected that each of the input sentences was augmented with
begin-sentence and end-sentence tokens.
Parameters
----------
tokens: ``torch.Tensor``, required.
@@ -210,73 +219,80 @@ def forward(self, # type: ignore
forward direction negative log likelihood
``'backward_loss'``: ``torch.Tensor``
backward direction negative log likelihood
``'lm_embeddings'``: ``torch.Tensor``
(batch_size, timesteps, embed_dim) tensor of top layer contextual representations
``'lm_embeddings'``: ``Union[torch.Tensor, List[torch.Tensor]]``
(batch_size, timesteps, embed_dim) tensor of top layer contextual representations or
list of all layers. No dropout applied.
``'noncontextual_token_embeddings'``: ``torch.Tensor``
(batch_size, timesteps, token_embed_dim) tensor of bottom layer noncontextual
representations
``'mask'``: ``torch.Tensor``
(batch_size, timesteps) mask for the embeddings
"""
# pylint: disable=arguments-differ
mask = get_text_field_mask(source)

# We must have token_ids so that we can compute targets
token_ids = source.get("tokens")
if token_ids is None:
raise ConfigurationError("Your data must have a 'tokens': SingleIdTokenIndexer() "
"in order to use the BidirectionalLM")

# Use token_ids to compute targets
forward_targets = torch.zeros_like(token_ids)
backward_targets = torch.zeros_like(token_ids)
forward_targets[:, 0:-1] = token_ids[:, 1:]
backward_targets[:, 1:] = token_ids[:, 0:-1]

# shape (batch_size, timesteps + 2, embedding_size)
# shape (batch_size, timesteps, embedding_size)
embeddings = self._text_field_embedder(source)

contextual_embeddings = self._contextualizer(embeddings, mask)

# add dropout
contextual_embeddings = self._dropout(contextual_embeddings)
# Either the top layer or all layers.
contextual_embeddings: Union[torch.Tensor, List[torch.Tensor]] = self._contextualizer(
embeddings, mask
)

# compute softmax loss
forward_loss, backward_loss = self._compute_loss(contextual_embeddings,
embeddings,
forward_targets,
backward_targets)
return_dict = {}

num_targets = torch.sum((forward_targets > 0).long())
if num_targets > 0:
average_loss = 0.5 * (forward_loss + backward_loss) / num_targets.float()
else:
average_loss = torch.tensor(0.0).to(forward_targets.device) # pylint: disable=not-callable
# this is stored to compute perplexity if needed
self._last_average_loss[0] = average_loss.detach().item()

if num_targets > 0:
# loss is directly minimized
if self._loss_scale == 'n_samples':
scale_factor = num_targets.float()
# If we have target tokens, calculate the loss.
token_ids = source.get("tokens")
if token_ids is not None:
assert isinstance(contextual_embeddings, torch.Tensor)

# Use token_ids to compute targets
forward_targets = torch.zeros_like(token_ids)
backward_targets = torch.zeros_like(token_ids)
forward_targets[:, 0:-1] = token_ids[:, 1:]
backward_targets[:, 1:] = token_ids[:, 0:-1]

# add dropout
contextual_embeddings_with_dropout = self._dropout(contextual_embeddings)

# compute softmax loss
forward_loss, backward_loss = self._compute_loss(contextual_embeddings_with_dropout,
embeddings,
forward_targets,
backward_targets)

num_targets = torch.sum((forward_targets > 0).long())
if num_targets > 0:
average_loss = 0.5 * (forward_loss + backward_loss) / num_targets.float()
else:
scale_factor = self._loss_scale

return_dict = {
'loss': average_loss * scale_factor,
'forward_loss': forward_loss * scale_factor / num_targets.float(),
'backward_loss': backward_loss * scale_factor / num_targets.float()
}
else:
# average_loss zero tensor, return it for all
return_dict = {
'loss': average_loss,
'forward_loss': average_loss,
'backward_loss': average_loss
}

if self._remove_bos_eos:
contextual_embeddings, mask = remove_sentence_boundaries(contextual_embeddings, mask)
average_loss = torch.tensor(0.0).to(forward_targets.device) # pylint: disable=not-callable
# this is stored to compute perplexity if needed
self._last_average_loss[0] = average_loss.detach().item()

if num_targets > 0:
# loss is directly minimized
if self._loss_scale == 'n_samples':
scale_factor = num_targets.float()
else:
scale_factor = self._loss_scale

return_dict.update({
'loss': average_loss * scale_factor,
'forward_loss': forward_loss * scale_factor / num_targets.float(),
'backward_loss': backward_loss * scale_factor / num_targets.float()
})
else:
# average_loss zero tensor, return it for all
return_dict.update({
'loss': average_loss,
'forward_loss': average_loss,
'backward_loss': average_loss
})

return_dict.update({
# Note: These embeddings do not have dropout applied.
'lm_embeddings': contextual_embeddings,
'noncontextual_token_embeddings': embeddings,
'mask': mask
})
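As a worked illustration of the target construction and loss averaging in forward(), here is a self-contained sketch with a toy batch; the ids and loss values are invented and 0 is treated as padding.

import torch

token_ids = torch.tensor([[2, 5, 7, 3, 0],     # one padded sentence per row
                          [2, 9, 3, 0, 0]])
mask = (token_ids > 0).long()                  # what get_text_field_mask yields for this toy single-id input

forward_targets = torch.zeros_like(token_ids)
backward_targets = torch.zeros_like(token_ids)
forward_targets[:, 0:-1] = token_ids[:, 1:]    # predict the next token
backward_targets[:, 1:] = token_ids[:, 0:-1]   # predict the previous token

num_targets = torch.sum((forward_targets > 0).long())                 # 5 real targets
forward_loss, backward_loss = torch.tensor(14.0), torch.tensor(12.0)  # pretend summed NLLs
average_loss = 0.5 * (forward_loss + backward_loss) / num_targets.float()
print(forward_targets)  # [[5, 7, 3, 0, 0], [9, 3, 0, 0, 0]]
print(average_loss)     # tensor(2.6000) = 0.5 * 26.0 / 5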

@@ -9,3 +9,5 @@
from allennlp.modules.token_embedders.elmo_token_embedder import ElmoTokenEmbedder
from allennlp.modules.token_embedders.openai_transformer_embedder import OpenaiTransformerEmbedder
from allennlp.modules.token_embedders.bert_token_embedder import BertEmbedder, PretrainedBertEmbedder
from allennlp.modules.token_embedders.bidirectional_language_model_token_embedder import \
BidirectionalLanguageModelTokenEmbedder
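With this re-export in place, the new embedder is importable from the package root:

from allennlp.modules.token_embedders import BidirectionalLanguageModelTokenEmbedder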