This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Adding an optional projection layer to ElmoTokenEmbedder (#1076)
matt-gardner committed Apr 12, 2018
1 parent 4c02d92 commit b72c838
Showing 4 changed files with 68 additions and 10 deletions.
9 changes: 5 additions & 4 deletions allennlp/modules/__init__.py
@@ -6,17 +6,18 @@
"""

from allennlp.modules.attention import Attention
from allennlp.modules.layer_norm import LayerNorm
from allennlp.modules.conditional_random_field import ConditionalRandomField
from allennlp.modules.elmo import Elmo
from allennlp.modules.feedforward import FeedForward
from allennlp.modules.highway import Highway
from allennlp.modules.layer_norm import LayerNorm
from allennlp.modules.matrix_attention import MatrixAttention
from allennlp.modules.maxout import Maxout
from allennlp.modules.scalar_mix import ScalarMix
from allennlp.modules.seq2seq_encoders import Seq2SeqEncoder
from allennlp.modules.seq2vec_encoders import Seq2VecEncoder
from allennlp.modules.similarity_functions import SimilarityFunction
from allennlp.modules.span_pruner import SpanPruner
from allennlp.modules.text_field_embedders import TextFieldEmbedder
from allennlp.modules.time_distributed import TimeDistributed
from allennlp.modules.token_embedders import TokenEmbedder
from allennlp.modules.scalar_mix import ScalarMix
from allennlp.modules.span_pruner import SpanPruner
from allennlp.modules.maxout import Maxout
6 changes: 6 additions & 0 deletions allennlp/modules/elmo.py
@@ -87,6 +87,9 @@ def __init__(self,
self.add_module('scalar_mix_{}'.format(k), scalar_mix)
self._scalar_mixes.append(scalar_mix)

def get_output_dim(self):
return self._elmo_lstm.get_output_dim()

def forward(self, # pylint: disable=arguments-differ
inputs: torch.Tensor) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:
"""
@@ -432,6 +435,9 @@ def __init__(self,
# Number of representation layers including context independent layer
self.num_layers = options['lstm']['n_layers'] + 1

def get_output_dim(self):
return 2 * self._token_embedder.get_output_dim()

def forward(self, # pylint: disable=arguments-differ
inputs: torch.Tensor) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:
"""
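The two get_output_dim() methods added above let callers ask the ELMo stack for its representation size instead of reaching into private attributes (as ElmoTokenEmbedder previously did). The following is a stripped-down sketch of the delegation chain, not AllenNLP code; the class names and the 16-dimensional token embedder are made up for illustration.

# Toy stand-ins that model only the new get_output_dim() chain.
class FakeTokenEmbedder:
    def get_output_dim(self) -> int:
        return 16  # size of the context-independent char-CNN embedding (made up)


class FakeElmoBiLm:
    def __init__(self) -> None:
        self._token_embedder = FakeTokenEmbedder()

    def get_output_dim(self) -> int:
        # Forward and backward LM states are concatenated, hence the factor of 2.
        return 2 * self._token_embedder.get_output_dim()


class FakeElmo:
    def __init__(self) -> None:
        self._elmo_lstm = FakeElmoBiLm()

    def get_output_dim(self) -> int:
        return self._elmo_lstm.get_output_dim()


assert FakeElmo().get_output_dim() == 32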
33 changes: 27 additions & 6 deletions allennlp/modules/token_embedders/elmo_token_embedder.py
@@ -3,6 +3,7 @@
from allennlp.common import Params
from allennlp.modules.token_embedders.token_embedder import TokenEmbedder
from allennlp.modules.elmo import Elmo
from allennlp.modules.time_distributed import TimeDistributed
from allennlp.data import Vocabulary


@@ -25,15 +26,20 @@ class ElmoTokenEmbedder(TokenEmbedder):
Should we apply layer normalization (passed to ``ScalarMix``)?
dropout : ``float``, optional.
The dropout value to be applied to the ELMo representations.
requires_grad: ``bool``, optional
requires_grad : ``bool``, optional
If True, compute gradient of ELMo parameters for fine tuning.
projection_dim : ``int``, optional
If given, we will project the ELMo embedding down to this dimension. We recommend that you
try using ELMo with a lot of dropout and no projection first, but we have found a few cases
where projection helps (particularly where there is very limited training data).
"""
def __init__(self,
options_file: str,
weight_file: str,
do_layer_norm: bool = False,
dropout: float = 0.5,
requires_grad: bool = False) -> None:
requires_grad: bool = False,
projection_dim: int = None) -> None:
super(ElmoTokenEmbedder, self).__init__()

self._elmo = Elmo(options_file,
@@ -42,10 +48,13 @@ def __init__(self,
do_layer_norm=do_layer_norm,
dropout=dropout,
requires_grad=requires_grad)
if projection_dim:
self._projection = torch.nn.Linear(self._elmo.get_output_dim(), projection_dim)
else:
self._projection = None

def get_output_dim(self):
# pylint: disable=protected-access
return 2 * self._elmo._elmo_lstm._token_embedder.get_output_dim()
return self._elmo.get_output_dim()

def forward(self, inputs: torch.Tensor) -> torch.Tensor: # pylint: disable=arguments-differ
"""
@@ -60,7 +69,13 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor: # pylint: disable=arguments-differ
``(batch_size, timesteps, embedding_dim)``
"""
elmo_output = self._elmo(inputs)
return elmo_output['elmo_representations'][0]
elmo_representations = elmo_output['elmo_representations'][0]
if self._projection:
projection = self._projection
for _ in range(elmo_representations.dim() - 2):
projection = TimeDistributed(projection)
elmo_representations = projection(elmo_representations)
return elmo_representations

@classmethod
def from_params(cls, vocab: Vocabulary, params: Params) -> 'ElmoTokenEmbedder':
@@ -71,5 +86,11 @@ def from_params(cls, vocab: Vocabulary, params: Params) -> 'ElmoTokenEmbedder':
requires_grad = params.pop('requires_grad', False)
do_layer_norm = params.pop_bool('do_layer_norm', False)
dropout = params.pop_float("dropout", 0.5)
projection_dim = params.pop_int("projection_dim", None)
params.assert_empty(cls.__name__)
return cls(options_file, weight_file, do_layer_norm, dropout, requires_grad=requires_grad)
return cls(options_file=options_file,
weight_file=weight_file,
do_layer_norm=do_layer_norm,
dropout=dropout,
requires_grad=requires_grad,
projection_dim=projection_dim)
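A minimal usage sketch of the new projection_dim option, mirroring the unit test added below; the fixture paths are taken from that test, and everything else is illustrative rather than part of the commit.

import torch
from torch.autograd import Variable

from allennlp.common import Params
from allennlp.modules.token_embedders import ElmoTokenEmbedder

params = Params({
        'options_file': 'tests/fixtures/elmo/options.json',
        'weight_file': 'tests/fixtures/elmo/lm_weights.hdf5',
        'projection_dim': 20
})
embedder = ElmoTokenEmbedder.from_params(vocab=None, params=params)

# Character ids have shape (batch_size, num_tokens, 50).  ELMo's output is
# passed through the new Linear layer, so the last dimension becomes 20.
word = [0] * 50
word[0], word[1], word[2], word[3] = 6, 5, 4, 3
character_ids = Variable(torch.LongTensor([[word, word]]))  # shape (1, 2, 50)
embedded = embedder(character_ids)
assert embedded.size() == (1, 2, 20)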
30 changes: 30 additions & 0 deletions tests/modules/token_embedders/elmo_token_embedder_test.py
@@ -4,10 +4,15 @@
import os
import tarfile

import torch
from torch.autograd import Variable

from allennlp.commands.train import train_model
from allennlp.common import Params
from allennlp.common.testing import ModelTestCase
from allennlp.data.dataset import Batch
from allennlp.modules.token_embedders import ElmoTokenEmbedder


class TestElmoTokenEmbedder(ModelTestCase):
def setUp(self):
@@ -63,3 +68,28 @@ def test_file_archiving(self):
for key, original_filename in files_to_archive.items():
new_filename = os.path.join(unarchive_dir, "fta", key)
assert filecmp.cmp(original_filename, new_filename)

def test_forward_works_with_projection_layer(self):
params = Params({
'options_file': 'tests/fixtures/elmo/options.json',
'weight_file': 'tests/fixtures/elmo/lm_weights.hdf5',
'projection_dim': 20
})
word1 = [0] * 50
word2 = [0] * 50
word1[0] = 6
word1[1] = 5
word1[2] = 4
word1[3] = 3
word2[0] = 3
word2[1] = 2
word2[2] = 1
word2[3] = 0
embedding_layer = ElmoTokenEmbedder.from_params(vocab=None, params=params)
input_tensor = Variable(torch.LongTensor([[word1, word2]]))
embedded = embedding_layer(input_tensor).data.numpy()
assert embedded.shape == (1, 2, 20)

input_tensor = Variable(torch.LongTensor([[[word1]]]))
embedded = embedding_layer(input_tensor).data.numpy()
assert embedded.shape == (1, 1, 1, 20)
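The rank-4 input in the last assertion exercises the TimeDistributed wrapping added to ElmoTokenEmbedder.forward(): each extra leading dimension gets one wrapper, so the Linear projection only ever sees a two-dimensional view. An isolated illustration of that reshaping, with made-up 32-to-20 dimensions:

import torch
from torch.autograd import Variable

from allennlp.modules.time_distributed import TimeDistributed

projection = torch.nn.Linear(32, 20)                  # dimensions are made up
representations = Variable(torch.randn(1, 1, 1, 32))  # rank-4, like the (1, 1, 1, 50) test input

# Same wrapping loop as in the new forward(): one TimeDistributed per extra dimension.
wrapped = projection
for _ in range(representations.dim() - 2):
    wrapped = TimeDistributed(wrapped)

assert wrapped(representations).size() == (1, 1, 1, 20)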
