This repository has been archived by the owner on Dec 16, 2022. It is now read-only.
/
sequence_label_field.py
156 lines (132 loc) · 6.3 KB
/
sequence_label_field.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
from typing import Dict, List, Union, Set, Iterator
import logging
import textwrap
from overrides import overrides
import torch
from allennlp.common.checks import ConfigurationError
from allennlp.common.util import pad_sequence_to_length
from allennlp.data.fields.field import Field
from allennlp.data.fields.sequence_field import SequenceField
from allennlp.data.vocabulary import Vocabulary
logger = logging.getLogger(__name__)
class SequenceLabelField(Field[torch.Tensor]):
    """
    A `SequenceLabelField` assigns a categorical label to each element in a
    :class:`~allennlp.data.fields.sequence_field.SequenceField`.
    Because it's a labeling of some other field, we take that field as input here, and we use it to
    determine our padding and other things.

    This field will get converted into a list of integer class ids, representing the correct class
    for each element in the sequence.

    # Parameters

    labels : `Union[List[str], List[int]]`
        A sequence of categorical labels, encoded as strings or integers. These could be POS tags like [NN,
        JJ, ...], BIO tags like [B-PERS, I-PERS, O, O, ...], or any other categorical tag sequence. If the
        labels are encoded as integers, they will not be indexed using a vocab.
    sequence_field : `SequenceField`
        A field containing the sequence that this `SequenceLabelField` is labeling. Most often, this is a
        `TextField`, for tagging individual tokens in a sentence.
    label_namespace : `str`, optional (default=`'labels'`)
        The namespace to use for converting tag strings into integers. We convert tag strings to
        integers for you, and this parameter tells the `Vocabulary` object which mapping from
        strings to integers to use (so that "O" as a tag doesn't get the same id as "O" as a word).
    """

    __slots__ = [
        "labels",
        "sequence_field",
        "_label_namespace",
        "_indexed_labels",
        "_skip_indexing",
    ]

    # It is possible that users want to use this field with a namespace which uses OOV/PAD tokens.
    # This warning will be repeated for every instantiation of this class (i.e for every data
    # instance), spewing a lot of warnings so this class variable is used to only log a single
    # warning per namespace.
    _already_warned_namespaces: Set[str] = set()

    def __init__(
        self,
        labels: Union[List[str], List[int]],
        sequence_field: SequenceField,
        label_namespace: str = "labels",
    ) -> None:
        self.labels = labels
        self.sequence_field = sequence_field
        self._label_namespace = label_namespace
        # Populated by `index()`, or immediately below if the labels are already ints.
        self._indexed_labels = None
        self._maybe_warn_for_namespace(label_namespace)
        if len(labels) != sequence_field.sequence_length():
            raise ConfigurationError(
                "Label length and sequence length "
                f"don't match: {len(labels)} and {sequence_field.sequence_length()}"
            )
        self._skip_indexing = False
        if all(isinstance(x, int) for x in labels):
            # Integer labels are taken to be pre-indexed class ids; no vocab lookup is needed.
            self._indexed_labels = labels
            self._skip_indexing = True
        elif not all(isinstance(x, str) for x in labels):
            raise ConfigurationError(
                "SequenceLabelFields must be passed either all "
                "strings or all ints. Found labels {} with "
                "types: {}.".format(labels, [type(x) for x in labels])
            )

    def _maybe_warn_for_namespace(self, label_namespace: str) -> None:
        """
        Log a one-time warning if `label_namespace` doesn't follow the `*labels` / `*tags`
        naming convention that keeps a namespace out of the vocabulary's padded namespaces.
        """
        if not (label_namespace.endswith("labels") or label_namespace.endswith("tags")):
            if label_namespace not in self._already_warned_namespaces:
                logger.warning(
                    "Your label namespace was '%s'. We recommend you use a namespace "
                    "ending with 'labels' or 'tags', so we don't add UNK and PAD tokens by "
                    "default to your vocabulary. See documentation for "
                    "`non_padded_namespaces` parameter in Vocabulary.",
                    label_namespace,
                )
                # Remember this namespace so we only warn once per namespace, not per instance.
                self._already_warned_namespaces.add(label_namespace)

    # Sequence methods

    def __iter__(self) -> Iterator[Union[str, int]]:
        return iter(self.labels)

    def __getitem__(self, idx: int) -> Union[str, int]:
        return self.labels[idx]

    def __len__(self) -> int:
        return len(self.labels)

    @overrides
    def count_vocab_items(self, counter: Dict[str, Dict[str, int]]):
        # Only count string labels; integer labels are already indexed and skip the vocab.
        if self._indexed_labels is None:
            for label in self.labels:
                counter[self._label_namespace][label] += 1  # type: ignore

    @overrides
    def index(self, vocab: Vocabulary):
        # No-op when the labels were passed in as ints (already indexed).
        if not self._skip_indexing:
            self._indexed_labels = [
                vocab.get_token_index(label, self._label_namespace)  # type: ignore
                for label in self.labels
            ]

    @overrides
    def get_padding_lengths(self) -> Dict[str, int]:
        # Padding is driven by the sequence we are labeling, not by our own label list.
        return {"num_tokens": self.sequence_field.sequence_length()}

    @overrides
    def as_tensor(self, padding_lengths: Dict[str, int]) -> torch.Tensor:
        """
        Pad the indexed labels to `padding_lengths["num_tokens"]` and return them as a
        long tensor. Raises `ConfigurationError` if `index()` has not been called yet.
        """
        if self._indexed_labels is None:
            raise ConfigurationError(
                "You must call .index(vocabulary) on a field before calling .as_tensor()"
            )
        desired_num_tokens = padding_lengths["num_tokens"]
        padded_tags = pad_sequence_to_length(self._indexed_labels, desired_num_tokens)
        # `torch.tensor` with an explicit dtype is the modern equivalent of the legacy
        # `torch.LongTensor(...)` constructor.
        return torch.tensor(padded_tags, dtype=torch.long)

    @overrides
    def empty_field(self) -> "SequenceLabelField":
        # The empty_list here is needed for mypy
        empty_list: List[str] = []
        sequence_label_field = SequenceLabelField(empty_list, self.sequence_field.empty_field())
        # Redundant with the constructor's int-label path (an empty list vacuously passes the
        # all-ints check), but kept explicit so the field is usable without calling `index()`.
        sequence_label_field._indexed_labels = empty_list
        return sequence_label_field

    def __str__(self) -> str:
        length = self.sequence_field.sequence_length()
        formatted_labels = "".join(
            "\t\t" + labels + "\n" for labels in textwrap.wrap(repr(self.labels), 100)
        )
        return (
            f"SequenceLabelField of length {length} with "
            f"labels:\n {formatted_labels} \t\tin namespace: '{self._label_namespace}'."
        )

    @overrides
    def human_readable_repr(self) -> Union[List[str], List[int]]:
        return self.labels