Allow to pass pretrained_file in embedding extension (with tests).
HarshTrivedi committed Jan 17, 2019
1 parent 996bdeb commit 71358d1
Showing 2 changed files with 48 additions and 6 deletions.
26 changes: 20 additions & 6 deletions allennlp/modules/token_embedders/embedding.py
@@ -147,10 +147,14 @@ def forward(self, inputs): # pylint: disable=arguments-differ
         return embedded

     @overrides
-    def extend_vocab(self, extended_vocab: Vocabulary, vocab_namespace: str = None):
+    def extend_vocab(self,
+                     extended_vocab: Vocabulary,
+                     vocab_namespace: str = None,
+                     pretrained_file: str = None):
         """
         Extends the embedding matrix according to the extended vocabulary.
-        Extended weight would be initialized with xavier uniform.
+        If pretrained_file is available, it will be used for the extended weight,
+        or else it will be initialized with xavier uniform.

         Parameters
         ----------
@@ -162,6 +166,10 @@ def extend_vocab(self, extended_vocab: Vocabulary, vocab_namespace: str = None):
             can pass it. If not passed, it will check if vocab_namespace used at the
             time of ``Embedding`` construction is available. If so, this namespace
             will be used or else default 'tokens' namespace will be used.
+        pretrained_file : str, (optional, default=None)
+            A file containing pretrained embeddings can be specified here. It can be
+            the path to a local file or a URL of a (cached) remote file. Check format
+            details in ``from_params`` of ``Embedding`` class.
         """
         # Caveat: For allennlp v0.8.1 and below, we weren't storing vocab_namespace as an attribute,
         # knowing which is necessary at time of embedding vocab extension. So old archive models are
@@ -172,11 +180,17 @@ def extend_vocab(self, extended_vocab: Vocabulary, vocab_namespace: str = None):
             vocab_namespace = "tokens"
             logging.warning("No vocab_namespace provided to Embedder.extend_vocab. Defaulting to 'tokens'.")

-        extended_num_embeddings = extended_vocab.get_vocab_size(vocab_namespace)
-        extra_num_embeddings = extended_num_embeddings - self.num_embeddings
         embedding_dim = self.weight.data.shape[-1]
-        extra_weight = torch.FloatTensor(extra_num_embeddings, embedding_dim)
-        torch.nn.init.xavier_uniform_(extra_weight)
+        if not pretrained_file:
+            extended_num_embeddings = extended_vocab.get_vocab_size(vocab_namespace)
+            extra_num_embeddings = extended_num_embeddings - self.num_embeddings
+            extra_weight = torch.FloatTensor(extra_num_embeddings, embedding_dim)
+            torch.nn.init.xavier_uniform_(extra_weight)
+        else:
+            whole_weight = _read_pretrained_embeddings_file(pretrained_file, embedding_dim,
+                                                            extended_vocab, vocab_namespace)
+            extra_weight = whole_weight[self.num_embeddings:, :]

         extended_weight = torch.cat([self.weight.data, extra_weight], dim=0)
         self.weight = torch.nn.Parameter(extended_weight, requires_grad=self.weight.requires_grad)
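
As an aside for readers of this diff: the sketch below restates the branching logic the hunk above introduces, using plain PyTorch outside of AllenNLP. It is a simplified illustration, not the library code; the helper load_whole_weight is a hypothetical stand-in for _read_pretrained_embeddings_file and is assumed to return a matrix covering the entire extended vocabulary.

import torch

def extend_weight(weight: torch.Tensor,
                  extended_vocab_size: int,
                  load_whole_weight=None) -> torch.nn.Parameter:
    # Sketch only: mirrors the extend_vocab branching above, with AllenNLP specifics removed.
    num_embeddings, embedding_dim = weight.shape
    if load_whole_weight is None:
        # No pretrained file: initialize just the newly added rows with xavier uniform.
        extra_weight = torch.empty(extended_vocab_size - num_embeddings, embedding_dim)
        torch.nn.init.xavier_uniform_(extra_weight)
    else:
        # Pretrained file: read embeddings for the whole extended vocab (hypothetical
        # load_whole_weight stands in for _read_pretrained_embeddings_file) and keep
        # only the rows belonging to the newly added tokens.
        whole_weight = load_whole_weight()  # shape: (extended_vocab_size, embedding_dim)
        extra_weight = whole_weight[num_embeddings:, :]
    extended_weight = torch.cat([weight.data, extra_weight], dim=0)
    return torch.nn.Parameter(extended_weight, requires_grad=weight.requires_grad)
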
28 changes: 28 additions & 0 deletions allennlp/tests/modules/token_embedders/embedding_test.py
@@ -281,3 +281,31 @@ def test_embedding_vocab_extension_without_stored_namespace(self):
         extended_weight = embedder.weight
         assert extended_weight.shape[0] == 5
         assert torch.all(extended_weight[:4, :] == original_weight[:4, :])
+
+    def test_embedding_vocab_extension_works_with_pretrained_embedding_file(self):
+        vocab = Vocabulary()
+        vocab.add_token_to_namespace('word1')
+        vocab.add_token_to_namespace('word2')
+
+        embeddings_filename = str(self.TEST_DIR / "embeddings2.gz")
+        with gzip.open(embeddings_filename, 'wb') as embeddings_file:
+            embeddings_file.write("word3 0.5 0.3 -6.0\n".encode('utf-8'))
+            embeddings_file.write("word4 1.0 2.3 -1.0\n".encode('utf-8'))
+            embeddings_file.write("word2 0.1 0.4 -4.0\n".encode('utf-8'))
+            embeddings_file.write("word1 1.0 2.3 -1.0\n".encode('utf-8'))
+
+        embedding_params = Params({"vocab_namespace": "tokens", "embedding_dim": 3,
+                                   "pretrained_file": embeddings_filename})
+        embedder = Embedding.from_params(vocab, embedding_params)
+        original_weight = embedder.weight
+
+        assert tuple(original_weight.size()) == (4, 3)  # 4 because of padding and OOV
+
+        vocab.add_token_to_namespace('word3')
+        embedder.extend_vocab(vocab, pretrained_file=embeddings_filename)  # default namespace
+        extended_weight = embedder.weight
+
+        extended_weight = embedder.weight
+        assert extended_weight.shape[0] == 5
+        assert torch.all(original_weight[:4, :] == extended_weight[:4, :])
+        assert torch.all(extended_weight[4, :] == torch.tensor([0.5, 0.3, -6.0]))
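
A brief usage note on the behaviour the test pins down: tokens already in the vocabulary keep the weights they were trained with, and only the newly added token ('word3' here) gets its vector looked up from the pretrained file, which uses the same plain "token value value ..." format written above (optionally gzipped). A hypothetical fine-tuning call could look like the following sketch; the file path is a placeholder, not a file shipped with AllenNLP.

# Hypothetical call after adding a new dataset's tokens to `extended_vocab`;
# the path below is a placeholder.
embedder.extend_vocab(extended_vocab,
                      vocab_namespace="tokens",
                      pretrained_file="/path/to/extra_embeddings.txt.gz")
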
