From d709e59ff3f54e137ea473d411056f68c197c266 Mon Sep 17 00:00:00 2001 From: Junlin Wang Date: Fri, 24 Apr 2020 17:27:57 -0700 Subject: [PATCH] modified the make_output_human_readable method in basic_classifier for allennlp-demo (#4038) * modified the make_output_human_readable method in basic_classifier, so that the output dict can contain a tokens field for allennlp-demo. This fix is only for BERT and may not be compatible with other models. * added namespace to the constructor * Formatting * Fix tests * Formatting Co-authored-by: Dirk Groeneveld Co-authored-by: Evan Pete Walsh --- allennlp/models/basic_classifier.py | 15 +++++++++++++-- allennlp/tests/commands/predict_test.py | 19 +++++++++++++------ 2 files changed, 26 insertions(+), 8 deletions(-) diff --git a/allennlp/models/basic_classifier.py b/allennlp/models/basic_classifier.py index 4ac5d930659..d7c00c0f5b3 100644 --- a/allennlp/models/basic_classifier.py +++ b/allennlp/models/basic_classifier.py @@ -6,7 +6,7 @@ from allennlp.data import TextFieldTensors, Vocabulary from allennlp.models.model import Model from allennlp.modules import FeedForward, Seq2SeqEncoder, Seq2VecEncoder, TextFieldEmbedder -from allennlp.nn import InitializerApplicator +from allennlp.nn import InitializerApplicator, util from allennlp.nn.util import get_text_field_mask from allennlp.training.metrics import CategoricalAccuracy @@ -57,6 +57,7 @@ def __init__( dropout: float = None, num_labels: int = None, label_namespace: str = "labels", + namespace: str = "tokens", initializer: InitializerApplicator = InitializerApplicator(), **kwargs, ) -> None: @@ -81,6 +82,7 @@ def __init__( else: self._dropout = None self._label_namespace = label_namespace + self._namespace = namespace if num_labels: self._num_labels = num_labels @@ -134,7 +136,7 @@ def forward( # type: ignore probs = torch.nn.functional.softmax(logits, dim=-1) output_dict = {"logits": logits, "probs": probs} - + output_dict["token_ids"] = 
util.get_token_ids_from_text_field_tensors(tokens) if label is not None: loss = self._loss(logits, label.long().view(-1)) output_dict["loss"] = loss @@ -163,6 +165,15 @@ def make_output_human_readable( ) classes.append(label_str) output_dict["label"] = classes + tokens = [] + for instance_tokens in output_dict["token_ids"]: + tokens.append( + [ + self.vocab.get_token_from_index(token_id.item(), namespace=self._namespace) + for token_id in instance_tokens + ] + ) + output_dict["tokens"] = tokens return output_dict def get_metrics(self, reset: bool = False) -> Dict[str, float]: diff --git a/allennlp/tests/commands/predict_test.py b/allennlp/tests/commands/predict_test.py index 26e345d06df..ab602aa929c 100644 --- a/allennlp/tests/commands/predict_test.py +++ b/allennlp/tests/commands/predict_test.py @@ -85,7 +85,7 @@ def test_works_with_known_model(self): assert len(results) == 2 for result in results: - assert set(result.keys()) == {"label", "logits", "probs"} + assert set(result.keys()) == {"label", "logits", "probs", "tokens", "token_ids"} shutil.rmtree(self.tempdir) @@ -111,7 +111,7 @@ def test_using_dataset_reader_works_with_known_model(self): assert len(results) == 3 for result in results: - assert set(result.keys()) == {"label", "logits", "loss", "probs"} + assert set(result.keys()) == {"label", "logits", "loss", "probs", "tokens", "token_ids"} shutil.rmtree(self.tempdir) @@ -247,7 +247,7 @@ def test_base_predictor(self): assert len(results) == 3 for result in results: - assert set(result.keys()) == {"logits", "probs", "label", "loss"} + assert set(result.keys()) == {"logits", "probs", "label", "loss", "tokens", "token_ids"} DEFAULT_PREDICTORS["basic_classifier"] = "text_classifier" def test_batch_prediction_works_with_known_model(self): @@ -275,7 +275,7 @@ def test_batch_prediction_works_with_known_model(self): assert len(results) == 2 for result in results: - assert set(result.keys()) == {"label", "logits", "probs"} + assert set(result.keys()) == {"label", 
"logits", "probs", "tokens", "token_ids"} shutil.rmtree(self.tempdir) @@ -326,7 +326,14 @@ def predict_json(self, inputs: JsonDict) -> JsonDict: assert len(results) == 2 # Overridden predictor should output extra field for result in results: - assert set(result.keys()) == {"label", "logits", "explicit", "probs"} + assert set(result.keys()) == { + "label", + "logits", + "explicit", + "probs", + "tokens", + "token_ids", + } shutil.rmtree(self.tempdir) @@ -385,7 +392,7 @@ def test_other_modules(self): assert len(results) == 2 # Overridden predictor should output extra field for result in results: - assert set(result.keys()) == {"label", "logits", "probs"} + assert set(result.keys()) == {"label", "logits", "probs", "tokens", "token_ids"} def test_alternative_file_formats(self): @Predictor.register("classification-csv")