Skip to content
Permalink
Browse files

Fix more BasicTextFieldEmbedder warnings (#2114)

  • Loading branch information
schmmd committed Nov 29, 2018
1 parent f757f7a commit db0096f1a0ba4c99c05042482cfbc76b2ddcbaa9
Showing with 38 additions and 31 deletions.
  1. +38 −31 allennlp/tests/modules/text_field_embedders/basic_text_field_embedder_test.py
@@ -18,17 +18,19 @@ def setUp(self):
self.vocab.add_token_to_namespace("3")
self.vocab.add_token_to_namespace("4")
params = Params({
"words1": {
"type": "embedding",
"embedding_dim": 2
},
"words2": {
"type": "embedding",
"embedding_dim": 5
},
"words3": {
"type": "embedding",
"embedding_dim": 3
"token_embedders": {
"words1": {
"type": "embedding",
"embedding_dim": 2
},
"words2": {
"type": "embedding",
"embedding_dim": 5
},
"words3": {
"type": "embedding",
"embedding_dim": 3
}
}
})
self.token_embedder = BasicTextFieldEmbedder.from_params(vocab=self.vocab, params=params)
@@ -54,23 +56,25 @@ def test_forward_concats_resultant_embeddings(self):

def test_forward_works_on_higher_order_input(self):
params = Params({
"words": {
"type": "embedding",
"num_embeddings": 20,
"embedding_dim": 2,
},
"characters": {
"type": "character_encoding",
"embedding": {
"embedding_dim": 4,
"num_embeddings": 15,
},
"encoder": {
"type": "cnn",
"embedding_dim": 4,
"num_filters": 10,
"ngram_filter_sizes": [3],
"token_embedders": {
"words": {
"type": "embedding",
"num_embeddings": 20,
"embedding_dim": 2,
},
"characters": {
"type": "character_encoding",
"embedding": {
"embedding_dim": 4,
"num_embeddings": 15,
},
"encoder": {
"type": "cnn",
"embedding_dim": 4,
"num_filters": 10,
"ngram_filter_sizes": [3],
},
}
}
})
token_embedder = BasicTextFieldEmbedder.from_params(vocab=self.vocab, params=params)
@@ -105,7 +109,6 @@ def test_forward_runs_with_non_bijective_mapping(self):
token_embedder(inputs)

def test_old_from_params_new_from_params(self):

old_params = Params({
"words1": {
"type": "embedding",
@@ -121,8 +124,9 @@ def test_old_from_params_new_from_params(self):
}
})

# Allow loading the parameters in the old format
with pytest.warns(DeprecationWarning):
BasicTextFieldEmbedder.from_params(params=old_params, vocab=self.vocab)
old_embedder = BasicTextFieldEmbedder.from_params(params=old_params, vocab=self.vocab)

new_params = Params({
"token_embedders": {
@@ -141,5 +145,8 @@ def test_old_from_params_new_from_params(self):
}
})

token_embedder = BasicTextFieldEmbedder.from_params(params=new_params, vocab=self.vocab)
assert token_embedder(self.inputs).size() == (1, 4, 10)
# But also allow loading the parameters in the new format
new_embedder = BasicTextFieldEmbedder.from_params(params=new_params, vocab=self.vocab)
assert old_embedder._token_embedders.keys() == new_embedder._token_embedders.keys() #pylint: disable=protected-access

assert new_embedder(self.inputs).size() == (1, 4, 10)

0 comments on commit db0096f

Please sign in to comment.
You can’t perform that action at this time.