This repository has been archived by the owner on Dec 16, 2022. It is now read-only.
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Constituency Parser Predictor and tests (#914)
* add constituency predictor, clean up some span padding in decode * regenerate test fixtures, address some regressions * batch test * fix up test, docs
- Loading branch information
Showing
24 changed files
with
230 additions
and
61 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
from typing import Tuple, List | ||
from overrides import overrides | ||
|
||
from allennlp.common.util import JsonDict, sanitize | ||
from allennlp.data import DatasetReader, Instance | ||
from allennlp.models import Model | ||
from allennlp.service.predictors.predictor import Predictor | ||
from allennlp.data.tokenizers.word_splitter import SpacyWordSplitter | ||
|
||
|
||
@Predictor.register('constituency-parser')
class ConstituencyParserPredictor(Predictor):
    """
    Wrapper for the :class:`~allennlp.models.SpanConstituencyParser` model.

    Tokenizes the incoming ``"sentence"`` field with spaCy, runs the model,
    and flattens the predicted NLTK tree to a single-line string under the
    ``"trees"`` key of the returned JSON.
    """
    def __init__(self, model: Model, dataset_reader: DatasetReader) -> None:
        super().__init__(model, dataset_reader)
        # spaCy-backed word splitter used to turn raw sentences into tokens.
        self._tokenizer = SpacyWordSplitter(language='en_core_web_sm')

    @overrides
    def _json_to_instance(self, json_dict: JsonDict) -> Tuple[Instance, JsonDict]:
        """
        Expects JSON that looks like ``{"sentence": "..."}``.

        Returns the model ``Instance`` plus a dict carrying the tokenized
        sentence, which is merged into the final output.
        """
        spacy_tokens = self._tokenizer.split_words(json_dict["sentence"])
        tokens = [token.text for token in spacy_tokens]
        instance = self._dataset_reader.text_to_instance(tokens)
        return instance, {"sentence": tokens}

    @overrides
    def predict_json(self, inputs: JsonDict, cuda_device: int = -1) -> JsonDict:
        instance, result = self._json_to_instance(inputs)
        model_output = self._model.forward_on_instance(instance, cuda_device)
        result.update(model_output)

        # Replace the NLTK tree object with its single-line string rendering.
        result["trees"] = result.pop("trees").pformat(margin=1000000)
        return sanitize(result)

    @overrides
    def predict_batch_json(self, inputs: List[JsonDict], cuda_device: int = -1) -> List[JsonDict]:
        instances, results = zip(*self._batch_json_to_instances(inputs))
        batch_output = self._model.forward_on_instances(instances, cuda_device)
        for model_output, result in zip(batch_output, results):
            result.update(model_output)
            # Replace the NLTK tree object with its single-line string rendering.
            result["trees"] = result.pop("trees").pformat(margin=1000000)
        return sanitize(results)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file not shown.
Binary file not shown.
Binary file not shown.
59 changes: 0 additions & 59 deletions
59
tests/fixtures/bidaf/serialization/vocabulary/token_characters.txt
This file was deleted.
Oops, something went wrong.
53 changes: 53 additions & 0 deletions
53
tests/fixtures/constituency_parser/experiment_no_evalb.json
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
{ | ||
"dataset_reader":{ | ||
"type":"ptb_trees", | ||
"use_pos_tags": false | ||
}, | ||
"train_data_path": "tests/fixtures/data/example_ptb.trees", | ||
"validation_data_path": "tests/fixtures/data/example_ptb.trees", | ||
"model": { | ||
"type": "constituency_parser", | ||
"text_field_embedder": { | ||
"tokens": { | ||
"type": "embedding", | ||
"embedding_dim": 2, | ||
"trainable": true | ||
} | ||
}, | ||
"encoder": { | ||
"type": "lstm", | ||
"input_size": 2, | ||
"hidden_size": 4, | ||
"num_layers": 1 | ||
}, | ||
"feedforward": { | ||
"input_dim": 4, | ||
"num_layers": 1, | ||
"hidden_dims": 4, | ||
"activations": "relu" | ||
}, | ||
"span_extractor": { | ||
"type": "endpoint", | ||
"input_dim": 4 | ||
} | ||
}, | ||
|
||
"iterator": { | ||
"type": "bucket", | ||
"sorting_keys": [["tokens", "num_tokens"]], | ||
"padding_noise": 0.0, | ||
"batch_size" : 5 | ||
}, | ||
"trainer": { | ||
"num_epochs": 1, | ||
"grad_norm": 1.0, | ||
"patience": 500, | ||
"cuda_device": -1, | ||
"optimizer": { | ||
"type": "adadelta", | ||
"lr": 0.000001, | ||
"rho": 0.95 | ||
} | ||
} | ||
} | ||
|
Binary file not shown.
Binary file not shown.
9 changes: 9 additions & 0 deletions
9
tests/fixtures/constituency_parser/serialization/vocabulary/labels.txt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
NO-LABEL | ||
NP | ||
VP | ||
S | ||
ADVP | ||
VROOT | ||
SBAR | ||
PP | ||
ADJP |
2 changes: 2 additions & 0 deletions
2
tests/fixtures/constituency_parser/serialization/vocabulary/non_padded_namespaces.txt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
*tags | ||
*labels |
42 changes: 42 additions & 0 deletions
42
tests/fixtures/constituency_parser/serialization/vocabulary/tokens.txt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
@@UNKNOWN@@ | ||
the | ||
to | ||
, | ||
UAL | ||
and | ||
other | ||
be | ||
him | ||
. | ||
Also | ||
because | ||
Chairman | ||
Stephen | ||
Wolf | ||
executives | ||
have | ||
joined | ||
pilots | ||
' | ||
bid | ||
board | ||
might | ||
forced | ||
exclude | ||
from | ||
its | ||
deliberations | ||
in | ||
order | ||
fair | ||
bidders | ||
That | ||
could | ||
cost | ||
chance | ||
influence | ||
outcome | ||
perhaps | ||
join | ||
winning | ||
bidder |
Binary file not shown.
Binary file not shown.
Binary file modified
BIN
+0 Bytes
(100%)
tests/fixtures/decomposable_attention/serialization/best.th
Binary file not shown.
Binary file modified
BIN
-31 Bytes
(100%)
tests/fixtures/decomposable_attention/serialization/model.tar.gz
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
# pylint: disable=no-self-use,invalid-name | ||
from unittest import TestCase | ||
|
||
from allennlp.models.archival import load_archive | ||
from allennlp.service.predictors import Predictor | ||
|
||
|
||
class TestConstituencyParserPredictor(TestCase):
    """End-to-end tests for the ``constituency-parser`` predictor fixture."""

    def _check_result(self, result, expected_num_spans, expected_tokens):
        # expected_num_spans is the number of possible substrings of the sentence.
        assert len(result["spans"]) == expected_num_spans
        assert len(result["class_probabilities"]) == expected_num_spans
        assert result["sentence"] == expected_tokens
        assert isinstance(result["trees"], str)

        # Each span's label distribution should be a valid probability distribution.
        for class_distribution in result["class_probabilities"]:
            self.assertAlmostEqual(sum(class_distribution), 1.0, places=4)

    def test_uses_named_inputs(self):
        inputs = {
                "sentence": "What a great test sentence.",
        }

        archive = load_archive('tests/fixtures/constituency_parser/serialization/model.tar.gz')
        predictor = Predictor.from_archive(archive, 'constituency-parser')
        result = predictor.predict_json(inputs)

        self._check_result(result, 21, ["What", "a", "great", "test", "sentence", "."])

    def test_batch_prediction(self):
        inputs = [
                {"sentence": "What a great test sentence."},
                {"sentence": "Here's another good, interesting one."}
        ]

        archive = load_archive('tests/fixtures/constituency_parser/serialization/model.tar.gz')
        predictor = Predictor.from_archive(archive, 'constituency-parser')
        results = predictor.predict_batch_json(inputs)

        self._check_result(results[0], 21, ["What", "a", "great", "test", "sentence", "."])
        self._check_result(results[1], 36, ["Here", "'s", "another", "good", ",", "interesting", "one", "."])