
Make SrlBert model use SrlEvalMetric (#3168)
* Switch SemanticRoleLabeler metric to SrlEvalScorer.

* Switch back to https links, per f9e2029

* Add ignore_classes to SrlEvalScorer and ignore V in SRL model

* Enable specifying path to srl-eval.pl

* Only run span metric if it is enabled, and during evaluation

* Add comment explaining ignore_classes

* Add doc for srl_util

* Add srl_util.rst to allennlp.models.rst

* Fix position of comment

* Use SrlEvalMetric in SrlBert
nelson-liu authored and DeNeutoy committed Aug 19, 2019
1 parent adad1bc commit 0f6b3b8
Showing 1 changed file with 36 additions and 15 deletions.
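
For orientation, a rough usage sketch follows; it is not part of the commit. It assumes a toy label vocabulary, the public "bert-base-uncased" checkpoint, and the embedding_dropout argument (collapsed out of the constructor signature in the diff below), and it only illustrates the new srl_eval_path option described above.

# Hypothetical sketch (not from this commit): constructing SrlBert with and without
# the perl-based scorer. Loading "bert-base-uncased" requires network access.
from allennlp.data import Vocabulary
from allennlp.models.srl_bert import SrlBert

# Toy label vocabulary so the tag projection layer has a non-empty output space.
vocab = Vocabulary()
for tag in ["O", "B-ARG0", "I-ARG0", "B-V"]:
    vocab.add_token_to_namespace(tag, namespace="labels")

# Default: span_metric is an SrlEvalScorer pointing at the bundled srl-eval.pl script.
model = SrlBert(vocab=vocab,
                bert_model="bert-base-uncased",
                embedding_dropout=0.1)  # dropout value chosen arbitrarily for this sketch

# Passing srl_eval_path=None sets span_metric to None, so srl-eval.pl is never invoked.
model_no_scorer = SrlBert(vocab=vocab,
                          bert_model="bert-base-uncased",
                          embedding_dropout=0.1,
                          srl_eval_path=None)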
allennlp/models/srl_bert.py
@@ -8,10 +8,11 @@
 
 from allennlp.data import Vocabulary
 from allennlp.models.model import Model
+from allennlp.models.srl_util import convert_bio_tags_to_conll_format
 from allennlp.nn import InitializerApplicator, RegularizerApplicator
 from allennlp.nn.util import get_text_field_mask, sequence_cross_entropy_with_logits
 from allennlp.nn.util import get_lengths_from_binary_sequence_mask, viterbi_decode
-from allennlp.training.metrics import SpanBasedF1Measure
+from allennlp.training.metrics.srl_eval_scorer import SrlEvalScorer, DEFAULT_SRL_EVAL_PATH
 
 @Model.register("srl_bert")
 class SrlBert(Model):
@@ -31,6 +32,9 @@ class SrlBert(Model):
         Whether or not to use label smoothing on the labels when computing cross entropy loss.
     ignore_span_metric: ``bool``, optional (default = False)
         Whether to calculate span loss, which is irrelevant when predicting BIO for Open Information Extraction.
+    srl_eval_path: ``str``, optional (default=``DEFAULT_SRL_EVAL_PATH``)
+        The path to the srl-eval.pl script. By default, will use the srl-eval.pl included with allennlp,
+        which is located at allennlp/tools/srl-eval.pl . If ``None``, srl-eval.pl is not used.
     """
     def __init__(self,
                  vocab: Vocabulary,
@@ -39,7 +43,8 @@ def __init__(self,
                  initializer: InitializerApplicator = InitializerApplicator(),
                  regularizer: Optional[RegularizerApplicator] = None,
                  label_smoothing: float = None,
-                 ignore_span_metric: bool = False) -> None:
+                 ignore_span_metric: bool = False,
+                 srl_eval_path: str = DEFAULT_SRL_EVAL_PATH) -> None:
         super(SrlBert, self).__init__(vocab, regularizer)
 
         if isinstance(bert_model, str):
@@ -48,9 +53,12 @@ def __init__(self,
             self.bert_model = bert_model
 
         self.num_classes = self.vocab.get_vocab_size("labels")
-        # For the span based evaluation, we don't want to consider labels
-        # for verb, because the verb index is provided to the model.
-        self.span_metric = SpanBasedF1Measure(vocab, tag_namespace="labels", ignore_classes=["V"])
+        if srl_eval_path is not None:
+            # For the span based evaluation, we don't want to consider labels
+            # for verb, because the verb index is provided to the model.
+            self.span_metric = SrlEvalScorer(srl_eval_path, ignore_classes=["V"])
+        else:
+            self.span_metric = None
         self.tag_projection_layer = Linear(self.bert_model.config.hidden_size, self.num_classes)
 
         self.embedding_dropout = Dropout(p=embedding_dropout)
@@ -110,25 +118,38 @@ def forward(self, # type: ignore
                                                                           sequence_length,
                                                                           self.num_classes])
         output_dict = {"logits": logits, "class_probabilities": class_probabilities}
-        if tags is not None:
-            loss = sequence_cross_entropy_with_logits(logits,
-                                                      tags,
-                                                      mask,
-                                                      label_smoothing=self._label_smoothing)
-            if not self.ignore_span_metric:
-                self.span_metric(class_probabilities, tags, mask)
-            output_dict["loss"] = loss
 
         # We need to retain the mask in the output dictionary
         # so that we can crop the sequences to remove padding
         # when we do viterbi inference in self.decode.
         output_dict["mask"] = mask
 
         # We add in the offsets here so we can compute the un-wordpieced tags.
         words, verbs, offsets = zip(*[(x["words"], x["verb"], x["offsets"]) for x in metadata])
         output_dict["words"] = list(words)
         output_dict["verb"] = list(verbs)
         output_dict["wordpiece_offsets"] = list(offsets)
 
+        if tags is not None:
+            loss = sequence_cross_entropy_with_logits(logits,
+                                                      tags,
+                                                      mask,
+                                                      label_smoothing=self._label_smoothing)
+            if not self.ignore_span_metric and self.span_metric is not None and not self.training:
+                batch_verb_indices = [example_metadata["verb_index"] for example_metadata in metadata]
+                batch_sentences = [example_metadata["words"] for example_metadata in metadata]
+                # Get the BIO tags from decode()
+                # TODO (nfliu): This is kind of a hack, consider splitting out part
+                # of decode() to a separate function.
+                batch_bio_predicted_tags = self.decode(output_dict).pop("tags")
+                batch_conll_predicted_tags = [convert_bio_tags_to_conll_format(tags) for
+                                              tags in batch_bio_predicted_tags]
+                batch_bio_gold_tags = [example_metadata["gold_tags"] for example_metadata in metadata]
+                batch_conll_gold_tags = [convert_bio_tags_to_conll_format(tags) for
+                                         tags in batch_bio_gold_tags]
+                self.span_metric(batch_verb_indices,
+                                 batch_sentences,
+                                 batch_conll_predicted_tags,
+                                 batch_conll_gold_tags)
+            output_dict["loss"] = loss
         return output_dict

@overrides
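
As a closing note, here is a hypothetical illustration of the conversion step that now feeds the scorer: predicted and gold BIO tag sequences are rewritten into CoNLL-style bracketed spans before SrlEvalScorer hands them to srl-eval.pl. The bracket strings in the comment are an assumption about convert_bio_tags_to_conll_format's output, not something asserted by this commit.

# Hypothetical illustration (not from this commit) of the BIO -> CoNLL conversion
# used in the new metric path.
from allennlp.models.srl_util import convert_bio_tags_to_conll_format

bio_tags = ["B-ARG0", "I-ARG0", "B-V", "B-ARG1", "O"]
conll_tags = convert_bio_tags_to_conll_format(bio_tags)
# Expected: one bracketed field per token, roughly ["(ARG0*", "*)", "(V*)", "(ARG1*)", "*"],
# which is the span encoding srl-eval.pl consumes.
print(conll_tags)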
