Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit

Permalink
Make SpanBasedF1Measure support BMES (#1692)
Browse files Browse the repository at this point in the history
* Make SpanBasedF1Measure support BMES.

* Decoration change.

* Fix code style.

* Add test for SpanBasedF1Measure(label_encoding="BMES").

* Bugfix.

* Fix testcase.

* Sse text[2:] instead of text.partition('-')[2]
  • Loading branch information
Haoxun Zhan authored and matt-gardner committed Aug 30, 2018
1 parent d16f6c0 commit 335d899
Show file tree
Hide file tree
Showing 5 changed files with 159 additions and 4 deletions.
69 changes: 69 additions & 0 deletions allennlp/data/dataset_readers/dataset_utils/span_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,3 +371,72 @@ def process_stack(stack, out_stack):
process_stack(stack, bioul_sequence)

return bioul_sequence


def bmes_tags_to_spans(tag_sequence: List[str],
classes_to_ignore: List[str] = None) -> List[TypedStringSpan]:
"""
Given a sequence corresponding to BMES tags, extracts spans.
Spans are inclusive and can be of zero length, representing a single word span.
Ill-formed spans are not allowed and will raise ``InvalidTagSequence``.
This function works properly when the spans are unlabeled (i.e., your labels are
simply "B", "M", "E" and "S").
Parameters
----------
tag_sequence : List[str], required.
The integer class labels for a sequence.
classes_to_ignore : List[str], optional (default = None).
A list of string class labels `excluding` the bio tag
which should be ignored when extracting spans.
Returns
-------
spans : List[TypedStringSpan]
The typed, extracted spans from the sequence, in the format (label, (span_start, span_end)).
Note that the label `does not` contain any BIO tag prefixes.
"""
def extract_bmes_tag_label(text):
bmes_tag = text[0]
label = text[2:]
return bmes_tag, label

spans = []
classes_to_ignore = classes_to_ignore or []
invalid = False
index = 0
while index < len(tag_sequence) and not invalid:
start_bmes_tag, start_label = extract_bmes_tag_label(tag_sequence[index])
start_index = index

if start_bmes_tag == 'B':
index += 1
while index < len(tag_sequence):
bmes_tag, label = extract_bmes_tag_label(tag_sequence[index])
# Stop conditions.
if label != start_label or bmes_tag not in ('M', 'E'):
invalid = True
break
if bmes_tag == 'E':
break
# bmes_tag == 'M', move to next.
index += 1

if index >= len(tag_sequence):
invalid = True
if not invalid:
spans.append((start_label, (start_index, index)))

elif start_bmes_tag == 'S':
spans.append((start_label, (start_index, start_index)))

else:
invalid = True

# Move to next span.
index += 1

if invalid:
raise InvalidTagSequence(tag_sequence)

return [span for span in spans if span[0] not in classes_to_ignore]
2 changes: 1 addition & 1 deletion allennlp/models/crf_tagger.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class CrfTagger(Model):
An optional feedforward layer to apply after the encoder.
label_encoding : ``str``, optional (default=``None``)
Label encoding to use when calculating span f1 and constraining
the CRF at decoding time . Valid options are "BIO", "BIOUL", "IOB1".
the CRF at decoding time . Valid options are "BIO", "BIOUL", "IOB1", "BMES".
Required if ``calculate_span_f1`` or ``constrain_crf_decoding`` is true.
constraint_type : ``str``, optional (default=``None``)
If provided, the CRF will be constrained at decoding time
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,3 +141,45 @@ def test_bio_to_bioul(self):
with self.assertRaises(span_utils.InvalidTagSequence):
tag_sequence = ['O', 'I-PER', 'B-PER', 'I-PER', 'I-PER', 'B-PER']
bioul_sequence = span_utils.to_bioul(tag_sequence, encoding="BIO")

def test_bmes_tags_to_spans_extracts_correct_spans(self):
tag_sequence = ["B-ARG1", "M-ARG1", "E-ARG1", "B-ARG2", "E-ARG2", "S-ARG3"]
spans = span_utils.bmes_tags_to_spans(tag_sequence)
assert set(spans) == {("ARG1", (0, 2)), ("ARG2", (3, 4)), ("ARG3", (5, 5))}

tag_sequence = ["S-ARG1", "B-ARG2", "E-ARG2", "S-ARG3"]
spans = span_utils.bmes_tags_to_spans(tag_sequence)
assert set(spans) == {("ARG1", (0, 0)), ("ARG2", (1, 2)), ("ARG3", (3, 3))}

# Check that it raises when labels are not correct.
tag_sequence = ["B-ARG1", "M-ARG2", "E-ARG1"]
with self.assertRaises(span_utils.InvalidTagSequence):
spans = span_utils.bmes_tags_to_spans(tag_sequence)

# Check that it raises when tag transitions are not correct.
tag_sequence = ["B-ARG1", "B-ARG1"]
with self.assertRaises(span_utils.InvalidTagSequence):
spans = span_utils.bmes_tags_to_spans(tag_sequence)
tag_sequence = ["B-ARG1", "S-ARG1"]
with self.assertRaises(span_utils.InvalidTagSequence):
spans = span_utils.bmes_tags_to_spans(tag_sequence)

def test_bmes_tags_to_spans_extracts_correct_spans_without_labels(self):
tag_sequence = ["B", "M", "E", "B", "E", "S"]
spans = span_utils.bmes_tags_to_spans(tag_sequence)
assert set(spans) == {("", (0, 2)), ("", (3, 4)), ("", (5, 5))}

tag_sequence = ["S", "B", "E", "S"]
spans = span_utils.bmes_tags_to_spans(tag_sequence)
assert set(spans) == {("", (0, 0)), ("", (1, 2)), ("", (3, 3))}

# Check that it raises when tag transitions are not correct.
tag_sequence = ["B", "B"]
with self.assertRaises(span_utils.InvalidTagSequence):
spans = span_utils.bmes_tags_to_spans(tag_sequence)
tag_sequence = ["B", "S"]
with self.assertRaises(span_utils.InvalidTagSequence):
spans = span_utils.bmes_tags_to_spans(tag_sequence)
tag_sequence = ["B", "E", "M"]
with self.assertRaises(span_utils.InvalidTagSequence):
spans = span_utils.bmes_tags_to_spans(tag_sequence)
40 changes: 40 additions & 0 deletions allennlp/tests/training/metrics/span_based_f1_measure_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,12 @@ def setUp(self):
vocab.add_token_to_namespace("B-ARGM-ADJ", "tags")
vocab.add_token_to_namespace("I-ARGM-ADJ", "tags")

# BMES.
vocab.add_token_to_namespace("B", "bmes_tags")
vocab.add_token_to_namespace("M", "bmes_tags")
vocab.add_token_to_namespace("E", "bmes_tags")
vocab.add_token_to_namespace("S", "bmes_tags")

self.vocab = vocab

def test_span_metrics_are_computed_correcly_with_prediction_map(self):
Expand Down Expand Up @@ -167,6 +173,40 @@ def test_span_metrics_are_computed_correctly(self):
numpy.testing.assert_almost_equal(metric_dict["precision-overall"], 0.5)
numpy.testing.assert_almost_equal(metric_dict["f1-measure-overall"], 0.5)

def test_bmes_span_metrics_are_computed_correctly(self):
# (bmes_tags) B:0, M:1, E:2, S:3.
# [S, B, M, E, S]
# [S, S, S, S, S]
gold_indices = [[3, 0, 1, 2, 3],
[3, 3, 3, 3, 3]]
gold_tensor = torch.Tensor(gold_indices)

prediction_tensor = torch.rand([2, 5, 4])
# [S, B, E, S, S]
# TP: 2, FP: 2, FN: 1.
prediction_tensor[0, 0, 3] = 1 # (True positive)
prediction_tensor[0, 1, 0] = 1 # (False positive
prediction_tensor[0, 2, 2] = 1 # *)
prediction_tensor[0, 3, 3] = 1 # (False positive)
prediction_tensor[0, 4, 3] = 1 # (True positive)
# [B, E, S, B, E]
# TP: 1, FP: 2, FN: 4.
prediction_tensor[1, 0, 0] = 1 # (False positive
prediction_tensor[1, 1, 2] = 1 # *)
prediction_tensor[1, 2, 3] = 1 # (True positive)
prediction_tensor[1, 3, 0] = 1 # (False positive
prediction_tensor[1, 4, 2] = 1 # *)

metric = SpanBasedF1Measure(self.vocab, "bmes_tags", label_encoding="BMES")
metric(prediction_tensor, gold_tensor)

# TP: 3, FP: 4, FN: 5.
metric_dict = metric.get_metric()

numpy.testing.assert_almost_equal(metric_dict["recall-overall"], 0.375)
numpy.testing.assert_almost_equal(metric_dict["precision-overall"], 0.428, decimal=3)
numpy.testing.assert_almost_equal(metric_dict["f1-measure-overall"], 0.4)

def test_span_f1_can_build_from_params(self):
params = Params({"type": "span_f1", "tag_namespace": "tags", "ignore_classes": ["V"]})
metric = Metric.from_params(params=params, vocabulary=self.vocab)
Expand Down
10 changes: 7 additions & 3 deletions allennlp/training/metrics/span_based_f1_measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
bio_tags_to_spans,
bioul_tags_to_spans,
iob1_tags_to_spans,
bmes_tags_to_spans,
TypedStringSpan
)

Expand Down Expand Up @@ -58,10 +59,10 @@ def __init__(self,
spans in a BIO tagging scheme which are typically not included.
label_encoding : ``str``, optional (default = "BIO")
The encoding used to specify label span endpoints in the sequence.
Valid options are "BIO", "IOB1", or BIOUL".
Valid options are "BIO", "IOB1", "BIOUL" or "BMES".
"""
if label_encoding not in ["BIO", "IOB1", "BIOUL"]:
raise ConfigurationError("Unknown label encoding - expected 'BIO', 'IOB1', 'BIOUL'.")
if label_encoding not in ["BIO", "IOB1", "BIOUL", "BMES"]:
raise ConfigurationError("Unknown label encoding - expected 'BIO', 'IOB1', 'BIOUL', 'BMES'.")

self._label_encoding = label_encoding
self._label_vocabulary = vocabulary.get_index_to_token_vocabulary(tag_namespace)
Expand Down Expand Up @@ -143,6 +144,9 @@ def __call__(self,
elif self._label_encoding == "BIOUL":
predicted_spans = bioul_tags_to_spans(predicted_string_labels, self._ignore_classes)
gold_spans = bioul_tags_to_spans(gold_string_labels, self._ignore_classes)
elif self._label_encoding == "BMES":
predicted_spans = bmes_tags_to_spans(predicted_string_labels, self._ignore_classes)
gold_spans = bmes_tags_to_spans(gold_string_labels, self._ignore_classes)

predicted_spans = self._handle_continued_spans(predicted_spans)
gold_spans = self._handle_continued_spans(gold_spans)
Expand Down

0 comments on commit 335d899

Please sign in to comment.