Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit

Permalink
modified the make_output_human_readable method in basic_classifier fo…
Browse files Browse the repository at this point in the history
…r allennlp-demo (#4038)

* modified the make_output_human_readable method in basic_classifier so that the output_dict can contain a tokens field for allennlp-demo. This fix is only for BERT and may not be compatible with other models.

* added namespace to the constructor

* Formatting

* Fix tests

* Formatting

Co-authored-by: Dirk Groeneveld <dirkg@allenai.org>
Co-authored-by: Evan Pete Walsh <epwalsh10@gmail.com>
  • Loading branch information
3 people committed Apr 25, 2020
1 parent 6ea6c59 commit d709e59
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 8 deletions.
15 changes: 13 additions & 2 deletions allennlp/models/basic_classifier.py
Expand Up @@ -6,7 +6,7 @@
from allennlp.data import TextFieldTensors, Vocabulary
from allennlp.models.model import Model
from allennlp.modules import FeedForward, Seq2SeqEncoder, Seq2VecEncoder, TextFieldEmbedder
from allennlp.nn import InitializerApplicator
from allennlp.nn import InitializerApplicator, util
from allennlp.nn.util import get_text_field_mask
from allennlp.training.metrics import CategoricalAccuracy

Expand Down Expand Up @@ -57,6 +57,7 @@ def __init__(
dropout: float = None,
num_labels: int = None,
label_namespace: str = "labels",
namespace: str = "tokens",
initializer: InitializerApplicator = InitializerApplicator(),
**kwargs,
) -> None:
Expand All @@ -81,6 +82,7 @@ def __init__(
else:
self._dropout = None
self._label_namespace = label_namespace
self._namespace = namespace

if num_labels:
self._num_labels = num_labels
Expand Down Expand Up @@ -134,7 +136,7 @@ def forward( # type: ignore
probs = torch.nn.functional.softmax(logits, dim=-1)

output_dict = {"logits": logits, "probs": probs}

output_dict["token_ids"] = util.get_token_ids_from_text_field_tensors(tokens)
if label is not None:
loss = self._loss(logits, label.long().view(-1))
output_dict["loss"] = loss
Expand Down Expand Up @@ -163,6 +165,15 @@ def make_output_human_readable(
)
classes.append(label_str)
output_dict["label"] = classes
tokens = []
for instance_tokens in output_dict["token_ids"]:
tokens.append(
[
self.vocab.get_token_from_index(token_id.item(), namespace=self._namespace)
for token_id in instance_tokens
]
)
output_dict["tokens"] = tokens
return output_dict

def get_metrics(self, reset: bool = False) -> Dict[str, float]:
Expand Down
19 changes: 13 additions & 6 deletions allennlp/tests/commands/predict_test.py
Expand Up @@ -85,7 +85,7 @@ def test_works_with_known_model(self):

assert len(results) == 2
for result in results:
assert set(result.keys()) == {"label", "logits", "probs"}
assert set(result.keys()) == {"label", "logits", "probs", "tokens", "token_ids"}

shutil.rmtree(self.tempdir)

Expand All @@ -111,7 +111,7 @@ def test_using_dataset_reader_works_with_known_model(self):

assert len(results) == 3
for result in results:
assert set(result.keys()) == {"label", "logits", "loss", "probs"}
assert set(result.keys()) == {"label", "logits", "loss", "probs", "tokens", "token_ids"}

shutil.rmtree(self.tempdir)

Expand Down Expand Up @@ -247,7 +247,7 @@ def test_base_predictor(self):

assert len(results) == 3
for result in results:
assert set(result.keys()) == {"logits", "probs", "label", "loss"}
assert set(result.keys()) == {"logits", "probs", "label", "loss", "tokens", "token_ids"}
DEFAULT_PREDICTORS["basic_classifier"] = "text_classifier"

def test_batch_prediction_works_with_known_model(self):
Expand Down Expand Up @@ -275,7 +275,7 @@ def test_batch_prediction_works_with_known_model(self):

assert len(results) == 2
for result in results:
assert set(result.keys()) == {"label", "logits", "probs"}
assert set(result.keys()) == {"label", "logits", "probs", "tokens", "token_ids"}

shutil.rmtree(self.tempdir)

Expand Down Expand Up @@ -326,7 +326,14 @@ def predict_json(self, inputs: JsonDict) -> JsonDict:
assert len(results) == 2
# Overridden predictor should output extra field
for result in results:
assert set(result.keys()) == {"label", "logits", "explicit", "probs"}
assert set(result.keys()) == {
"label",
"logits",
"explicit",
"probs",
"tokens",
"token_ids",
}

shutil.rmtree(self.tempdir)

Expand Down Expand Up @@ -385,7 +392,7 @@ def test_other_modules(self):
assert len(results) == 2
# Overridden predictor should output extra field
for result in results:
assert set(result.keys()) == {"label", "logits", "probs"}
assert set(result.keys()) == {"label", "logits", "probs", "tokens", "token_ids"}

def test_alternative_file_formats(self):
@Predictor.register("classification-csv")
Expand Down

0 comments on commit d709e59

Please sign in to comment.