This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Turn BidirectionalLM into a more-general LanguageModel class #2264

Merged
merged 49 commits into master from unidirectional_lm on Jan 8, 2019
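This PR generalizes the bidirectional language model into a LanguageModel class that works with either a unidirectional or a bidirectional contextualizer; BidirectionalLanguageModel is kept as a thin subclass that passes bidirectional=True (see the super().__init__ call at the end of the bidirectional_lm.py diff). Below is a minimal sketch of the resulting API, assuming the constructor parameters shown in that call; the toy vocabulary, embedding sizes, and LSTM contextualizers are illustrative assumptions, not taken from the PR.

import torch

from allennlp.data.vocabulary import Vocabulary
from allennlp.models.bidirectional_lm import BidirectionalLanguageModel
from allennlp.models.language_model import LanguageModel
from allennlp.modules.seq2seq_encoders import PytorchSeq2SeqWrapper
from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
from allennlp.modules.token_embedders import Embedding

vocab = Vocabulary()
for word in ["the", "quick", "brown", "fox"]:
    vocab.add_token_to_namespace(word)

embedder = BasicTextFieldEmbedder(
    {"tokens": Embedding(num_embeddings=vocab.get_vocab_size(), embedding_dim=16)})

# The general class accepts any contextualizer; here, a forward-only LSTM.
forward_lstm = PytorchSeq2SeqWrapper(
    torch.nn.LSTM(input_size=16, hidden_size=16, batch_first=True))
unidirectional_lm = LanguageModel(vocab=vocab,
                                  text_field_embedder=embedder,
                                  contextualizer=forward_lstm,
                                  bidirectional=False)

# The existing entry point still works and now just delegates to LanguageModel
# with bidirectional=True.
bidirectional_lstm = PytorchSeq2SeqWrapper(
    torch.nn.LSTM(input_size=16, hidden_size=8, batch_first=True, bidirectional=True))
bidirectional_lm = BidirectionalLanguageModel(vocab=vocab,
                                              text_field_embedder=embedder,
                                              contextualizer=bidirectional_lstm)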
Commits (49)
85e2933
Change BidirectionalLM to more-general ShuffledSentenceLM
nelson-liu Jan 2, 2019
1ce937c
Remove extraneous tests (they come from superclass)
nelson-liu Jan 2, 2019
12070b2
Convert BidirectionalLMTokenEmbedder to ShuffledSentenceLMTokenEmbedder
nelson-liu Jan 2, 2019
2fbddc6
Edit fixture archive model type to shuffled_sentence_language_model, …
nelson-liu Jan 2, 2019
bad8142
Remove empty temp model.tar.gz
nelson-liu Jan 2, 2019
66dccd3
Fix lint
nelson-liu Jan 2, 2019
e2ee301
Smooth prose in ShuffledSentenceLM docs
nelson-liu Jan 2, 2019
76580ce
Rename test classes in shuffled_sentence_lm_test.py
nelson-liu Jan 2, 2019
11116c5
Test bidirectionality of LM and contextualizer match, in bidirectiona…
nelson-liu Jan 2, 2019
a9f0089
Test unidirectional shuffled sentence lm
nelson-liu Jan 2, 2019
40131bb
Fix lint
nelson-liu Jan 2, 2019
0f86a7f
Fix docs
nelson-liu Jan 2, 2019
833725e
Fix odd test failure in simple_tagger_test
nelson-liu Jan 2, 2019
2ba2a21
Fix alphabetical order in model docs
nelson-liu Jan 2, 2019
1ad77b7
Merge branch 'master' into unidirectional_lm
nelson-liu Jan 3, 2019
1a6ff85
Properly deprecate BidirectionalLM* classes
nelson-liu Jan 3, 2019
7269065
Merge branch 'master' into unidirectional_lm
nelson-liu Jan 3, 2019
ba75e0d
Add deprecation notices to docstrings
nelson-liu Jan 3, 2019
82d572c
Fix lint
nelson-liu Jan 3, 2019
2fd730c
Convert bidirectional_lm.jsonnet to bidirectional_shuffled_sentence_l…
nelson-liu Jan 3, 2019
f23d081
Remove ShuffledSentence prefix from classes and filenames
nelson-liu Jan 3, 2019
81d8efd
Fix docs and training config
nelson-liu Jan 3, 2019
de6d4fb
Fix whitespace in bidirectional_lm docstring
nelson-liu Jan 3, 2019
228443f
Don't deprecate BidirectionalLM
nelson-liu Jan 3, 2019
53e4c2d
fix lint
nelson-liu Jan 3, 2019
cb315cf
Make unsampled tests actually test unsampled case
nelson-liu Jan 3, 2019
54d455f
Add test for BidirectionalLM
nelson-liu Jan 4, 2019
2b072b5
Fix lint
nelson-liu Jan 4, 2019
76e2667
Fix doc and lint in language_model.py
nelson-liu Jan 4, 2019
55dc4c4
Make backward loss None in completely-masked case
nelson-liu Jan 4, 2019
d67f3ba
Fix bidirectional lm token embedder
nelson-liu Jan 4, 2019
16bf1e4
Don't deprecate bidirectional lm token embedder
nelson-liu Jan 4, 2019
b6b7d7c
Remove deprecated from docstring
nelson-liu Jan 4, 2019
6da6985
Deduplicate jsonnet config for unidirectional unsampled test fixture
nelson-liu Jan 4, 2019
036596d
Deduplicate bidirectional lm test fixtures
nelson-liu Jan 4, 2019
57a07f5
Test unsampled bidirectional lm
nelson-liu Jan 4, 2019
08b758b
Fix lint
nelson-liu Jan 4, 2019
5af15c7
Refactor and modularize bidirectional_lm / language_model tests
nelson-liu Jan 4, 2019
dbb9aa7
Add test for bidirectional_language_model_token_embedder
nelson-liu Jan 4, 2019
940d4f8
Fix lint
nelson-liu Jan 4, 2019
3995513
Don't remove bidirectional-language-model name
nelson-liu Jan 4, 2019
1151f4d
Refactor loop in LM loss calculation
nelson-liu Jan 4, 2019
2772814
Merge branch 'master' into unidirectional_lm
nelson-liu Jan 4, 2019
b38513d
Fix min count for vocab generation in bilm config
nelson-liu Jan 6, 2019
f6750d4
Merge branch 'master' into unidirectional_lm
nelson-liu Jan 6, 2019
cdb0d24
Merge branch 'master' into unidirectional_lm
nelson-liu Jan 6, 2019
ba5b149
Merge branch 'master' into unidirectional_lm
nelson-liu Jan 7, 2019
7b24424
Merge branch 'master' into unidirectional_lm
nelson-liu Jan 8, 2019
3d3c1fe
Update references to bidirectional_lm.jsonnet to bidirectional_langua…
nelson-liu Jan 8, 2019
1 change: 1 addition & 0 deletions allennlp/models/__init__.py
@@ -27,3 +27,4 @@
from allennlp.models.bimpm import BiMpm
from allennlp.models.graph_parser import GraphParser
from allennlp.models.bidirectional_lm import BidirectionalLanguageModel
from allennlp.models.language_model import LanguageModel
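With this one-line addition, LanguageModel is importable from the top-level models package next to the existing BidirectionalLanguageModel. A trivial sanity check, assuming an install of allennlp at this commit:

from allennlp.models import BidirectionalLanguageModel, LanguageModel

# BidirectionalLanguageModel is now a subclass of the general LanguageModel,
# as the bidirectional_lm.py diff below shows.
assert issubclass(BidirectionalLanguageModel, LanguageModel)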
265 changes: 14 additions & 251 deletions allennlp/models/bidirectional_lm.py
@@ -1,57 +1,19 @@
from typing import Dict, List, Tuple, Union
from typing import Union

import torch
import numpy as np

from allennlp.common.checks import ConfigurationError
from allennlp.data.vocabulary import Vocabulary
from allennlp.models.language_model import LanguageModel
from allennlp.models.model import Model
from allennlp.modules.text_field_embedders import TextFieldEmbedder
from allennlp.modules.sampled_softmax_loss import SampledSoftmaxLoss
from allennlp.modules.seq2seq_encoders import Seq2SeqEncoder
from allennlp.nn.util import get_text_field_mask
from allennlp.nn import InitializerApplicator


class _SoftmaxLoss(torch.nn.Module):
"""
Given some embeddings and some targets, applies a linear layer
to create logits over possible words and then returns the
negative log likelihood.
"""
def __init__(self,
num_words: int,
embedding_dim: int) -> None:
super().__init__()

# TODO(joelgrus): implement tie_embeddings (maybe)
self.tie_embeddings = False

self.softmax_w = torch.nn.Parameter(
torch.randn(embedding_dim, num_words) / np.sqrt(embedding_dim)
)
self.softmax_b = torch.nn.Parameter(torch.zeros(num_words))

def forward(self, embeddings: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
# pylint: disable=arguments-differ
# embeddings is size (n, embedding_dim)
# targets is (batch_size, ) with the correct class id
# Does not do any count normalization / divide by batch size
probs = torch.nn.functional.log_softmax(
torch.matmul(embeddings, self.softmax_w) + self.softmax_b,
dim=-1
)

return torch.nn.functional.nll_loss(probs, targets.long(), reduction="sum")


@Model.register('bidirectional-language-model',
deprecation_message=('The "bidirectional-language-model" name was '
'deprecated in version 0.8 and will be removed'
'in version 0.10 . '
'Use "bidirectional_language_model" instead.'))
'deprecated in version 0.8. Use '
'"bidirectional_language_model" instead.'))
@Model.register('bidirectional_language_model')
class BidirectionalLanguageModel(Model):
class BidirectionalLanguageModel(LanguageModel):
"""
The ``BidirectionalLanguageModel`` applies a bidirectional "contextualizing"
``Seq2SeqEncoder`` to uncontextualized embeddings, using a ``SoftmaxLoss``
@@ -94,211 +56,12 @@ def __init__(self,
num_samples: int = None,
sparse_embeddings: bool = False,
initializer: InitializerApplicator = None) -> None:
super().__init__(vocab)
self._text_field_embedder = text_field_embedder

if not contextualizer.is_bidirectional():
raise ConfigurationError("contextualizer must be bidirectional")

self._contextualizer = contextualizer
# The dimension for making predictions just in the forward
# (or backward) direction.
self._forward_dim = contextualizer.get_output_dim() // 2

# TODO(joelgrus): more sampled softmax configuration options, as needed.
if num_samples is not None:
self._softmax_loss = SampledSoftmaxLoss(num_words=vocab.get_vocab_size(),
embedding_dim=self._forward_dim,
num_samples=num_samples,
sparse=sparse_embeddings)
else:
self._softmax_loss = _SoftmaxLoss(num_words=vocab.get_vocab_size(),
embedding_dim=self._forward_dim)

# TODO(brendanr): Output perplexity here. e^loss
self.register_buffer('_last_average_loss', torch.zeros(1))

if dropout:
self._dropout = torch.nn.Dropout(dropout)
else:
self._dropout = lambda x: x

self._loss_scale = loss_scale
if initializer is not None:
initializer(self)

def _get_target_token_embedding(self,
token_embeddings: torch.Tensor,
mask: torch.Tensor,
direction: int) -> torch.Tensor:
# Need to shift the mask in the correct direction
zero_col = token_embeddings.new_zeros(mask.size(0), 1).byte()
if direction == 0:
# forward direction, get token to right
shifted_mask = torch.cat([zero_col, mask[:, 0:-1]], dim=1)
else:
shifted_mask = torch.cat([mask[:, 1:], zero_col], dim=1)
return token_embeddings.masked_select(shifted_mask.unsqueeze(-1)).view(-1, self._forward_dim)

def _compute_loss(self,
lm_embeddings: torch.Tensor,
token_embeddings: torch.Tensor,
forward_targets: torch.Tensor,
backward_targets: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
# lm_embeddings is shape (batch_size, timesteps, dim * 2)
# forward_targets, backward_targets are shape (batch_size, timesteps)
# masked with 0
forward_embeddings, backward_embeddings = lm_embeddings.chunk(2, -1)
losses: List[torch.Tensor] = []
for idx, embedding, targets in ((0, forward_embeddings, forward_targets),
(1, backward_embeddings, backward_targets)):
mask = targets > 0
# we need to subtract 1 to undo the padding id since the softmax
# does not include a padding dimension

# shape (batch_size * timesteps, )
non_masked_targets = targets.masked_select(mask) - 1

# shape (batch_size * timesteps, embedding_dim)
non_masked_embedding = embedding.masked_select(
mask.unsqueeze(-1)
).view(-1, self._forward_dim)
# note: need to return average loss across forward and backward
# directions, but total sum loss across all batches.
# Assuming batches include full sentences, forward and backward
# directions have the same number of samples, so sum up loss
# here then divide by 2 just below
if not self._softmax_loss.tie_embeddings or not self._use_character_inputs:
losses.append(self._softmax_loss(non_masked_embedding, non_masked_targets))
else:
# we also need the token embeddings corresponding to the targets
raise NotImplementedError("This requires SampledSoftmaxLoss, which isn't implemented yet.")
# pylint: disable=unreachable
non_masked_token_embedding = self._get_target_token_embedding(token_embeddings, mask, idx)
losses.append(self._softmax(non_masked_embedding,
non_masked_targets,
non_masked_token_embedding))

return losses[0], losses[1]

def delete_softmax(self) -> None:
"""
Remove the softmax weights. Useful for saving memory when calculating the loss
is not necessary, e.g. in an embedder.
"""
self._softmax_loss = None

def num_layers(self) -> int:
"""
Returns the depth of this LM. That is, how many layers the contextualizer has plus one for
the non-contextual layer.
"""
if hasattr(self._contextualizer, 'num_layers'):
return self._contextualizer.num_layers + 1
else:
raise NotImplementedError(f"Contextualizer of type {type(self._contextualizer)} " +
"does not report how many layers it has.")

def forward(self, # type: ignore
source: Dict[str, torch.LongTensor]) -> Dict[str, torch.Tensor]:
"""
Computes the averaged forward and backward LM loss from the batch.

By convention, the input dict is required to have at least a ``"tokens"``
entry that's the output of a ``SingleIdTokenIndexer``, which is used
to compute the language model targets.

Parameters
----------
tokens: ``torch.Tensor``, required.
The output of ``Batch.as_tensor_dict()`` for a batch of sentences.

Returns
-------
Dict with keys:

``'loss'``: ``torch.Tensor``
averaged forward/backward negative log likelihood
``'forward_loss'``: ``torch.Tensor``
forward direction negative log likelihood
``'backward_loss'``: ``torch.Tensor``
backward direction negative log likelihood
``'lm_embeddings'``: ``Union[torch.Tensor, List[torch.Tensor]]``
(batch_size, timesteps, embed_dim) tensor of top layer contextual representations or
list of all layers. No dropout applied.
``'noncontextual_token_embeddings'``: ``torch.Tensor``
(batch_size, timesteps, token_embed_dim) tensor of bottom layer noncontextual
representations
``'mask'``: ``torch.Tensor``
(batch_size, timesteps) mask for the embeddings
"""
# pylint: disable=arguments-differ
mask = get_text_field_mask(source)

# shape (batch_size, timesteps, embedding_size)
embeddings = self._text_field_embedder(source)

# Either the top layer or all layers.
contextual_embeddings: Union[torch.Tensor, List[torch.Tensor]] = self._contextualizer(
embeddings, mask
)

return_dict = {}

# If we have target tokens, calculate the loss.
token_ids = source.get("tokens")
if token_ids is not None:
assert isinstance(contextual_embeddings, torch.Tensor)

# Use token_ids to compute targets
forward_targets = torch.zeros_like(token_ids)
backward_targets = torch.zeros_like(token_ids)
forward_targets[:, 0:-1] = token_ids[:, 1:]
backward_targets[:, 1:] = token_ids[:, 0:-1]

# add dropout
contextual_embeddings_with_dropout = self._dropout(contextual_embeddings)

# compute softmax loss
forward_loss, backward_loss = self._compute_loss(contextual_embeddings_with_dropout,
embeddings,
forward_targets,
backward_targets)

num_targets = torch.sum((forward_targets > 0).long())
if num_targets > 0:
average_loss = 0.5 * (forward_loss + backward_loss) / num_targets.float()
else:
average_loss = torch.tensor(0.0).to(forward_targets.device) # pylint: disable=not-callable
# this is stored to compute perplexity if needed
self._last_average_loss[0] = average_loss.detach().item()

if num_targets > 0:
# loss is directly minimized
if self._loss_scale == 'n_samples':
scale_factor = num_targets.float()
else:
scale_factor = self._loss_scale

return_dict.update({
'loss': average_loss * scale_factor,
'forward_loss': forward_loss * scale_factor / num_targets.float(),
'backward_loss': backward_loss * scale_factor / num_targets.float()
})
else:
# average_loss zero tensor, return it for all
return_dict.update({
'loss': average_loss,
'forward_loss': average_loss,
'backward_loss': average_loss
})

return_dict.update({
# Note: These embeddings do not have dropout applied.
'lm_embeddings': contextual_embeddings,
'noncontextual_token_embeddings': embeddings,
'mask': mask
})

return return_dict
super().__init__(vocab=vocab,
text_field_embedder=text_field_embedder,
contextualizer=contextualizer,
dropout=dropout,
loss_scale=loss_scale,
num_samples=num_samples,
sparse_embeddings=sparse_embeddings,
bidirectional=True,
initializer=initializer)
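For reference, the heart of the removed loss computation, which per this PR now lives in LanguageModel, is the target shifting and masked averaging shown above. A standalone sketch based on the removed code; the token ids and shapes are illustrative:

import torch

# (batch_size, timesteps) token ids, with 0 as padding, as in the removed forward().
token_ids = torch.tensor([[2, 5, 7, 0],
                          [3, 4, 0, 0]])

# Forward targets predict the next token; backward targets predict the previous one.
forward_targets = torch.zeros_like(token_ids)
backward_targets = torch.zeros_like(token_ids)
forward_targets[:, 0:-1] = token_ids[:, 1:]
backward_targets[:, 1:] = token_ids[:, 0:-1]

# Padded positions (target id 0) are dropped from the loss, and target ids are
# shifted down by one because the softmax has no padding dimension.
mask = forward_targets > 0
non_masked_forward_targets = forward_targets.masked_select(mask) - 1

# The reported loss averages the two directions' summed losses and normalizes
# by the number of targets: 0.5 * (forward_loss + backward_loss) / num_targets.
num_targets = torch.sum((forward_targets > 0).long())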