bert_crf_tagger.py
from typing import Dict, Optional, List, Any

from overrides import overrides
import torch
from torch.nn.modules.linear import Linear

from allennlp.common.checks import ConfigurationError
from allennlp.data import Vocabulary
from allennlp.modules import TimeDistributed, TextFieldEmbedder
from allennlp.modules import ConditionalRandomField
from allennlp.modules.conditional_random_field import allowed_transitions
from allennlp.models.model import Model
from allennlp.nn import InitializerApplicator, RegularizerApplicator
import allennlp.nn.util as util
from allennlp.training.metrics import CategoricalAccuracy, SpanBasedF1Measure, F1Measure


@Model.register("bert_crf_tagger")
class BertCrfTagger(Model):
"""
The ``BertCrfTagger`` encodes a sequence of text with a ``Seq2SeqEncoder``,
then uses a Conditional Random Field model to predict a tag for each token in the sequence.
Parameters
----------
vocab : ``Vocabulary``, required
A Vocabulary, required in order to compute sizes for input/output projections.
text_field_embedder : ``TextFieldEmbedder``, required
Used to embed the tokens ``TextField`` we get as input to the model.
encoder : ``Seq2SeqEncoder``
The encoder that we will use in between embedding tokens and predicting output tags.
label_namespace : ``str``, optional (default=``labels``)
This is needed to compute the SpanBasedF1Measure metric.
Unless you did something unusual, the default value should be what you want.
feedforward : ``FeedForward``, optional, (default = None).
An optional feedforward layer to apply after the encoder.
label_encoding : ``str``, optional (default=``None``)
Label encoding to use when calculating span f1 and constraining
the CRF at decoding time . Valid options are "BIO", "BIOUL", "IOB1", "BMES".
Required if ``calculate_span_f1`` or ``constrain_crf_decoding`` is true.
include_start_end_transitions : ``bool``, optional (default=``True``)
Whether to include start and end transition parameters in the CRF.
constrain_crf_decoding : ``bool``, optional (default=``None``)
If ``True``, the CRF is constrained at decoding time to
produce valid sequences of tags. If this is ``True``, then
``label_encoding`` is required. If ``None`` and
label_encoding is specified, this is set to ``True``.
If ``None`` and label_encoding is not specified, it defaults
to ``False``.
calculate_span_f1 : ``bool``, optional (default=``None``)
Calculate span-level F1 metrics during training. If this is ``True``, then
``label_encoding`` is required. If ``None`` and
label_encoding is specified, this is set to ``True``.
If ``None`` and label_encoding is not specified, it defaults
to ``False``.
dropout: ``float``, optional (default=``None``)
verbose_metrics : ``bool``, optional (default = False)
If true, metrics will be returned per label class in addition
to the overall statistics.
initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
Used to initialize the model parameters.
regularizer : ``RegularizerApplicator``, optional (default=``None``)
If provided, will be used to calculate the regularization penalty during training.
"""
def __init__(self, vocab: Vocabulary,
text_field_embedder: TextFieldEmbedder,
label_namespace: str = "labels",
label_encoding: Optional[str] = None,
include_start_end_transitions: bool = True,
                 constrain_crf_decoding: Optional[bool] = None,
                 calculate_span_f1: Optional[bool] = None,
dropout: float = 0.1,
verbose_metrics: bool = False,
initializer: InitializerApplicator = InitializerApplicator(),
regularizer: Optional[RegularizerApplicator] = None) -> None:
super().__init__(vocab, regularizer)
self.label_namespace = label_namespace
self.text_field_embedder = text_field_embedder
self.num_tags = self.vocab.get_vocab_size(label_namespace)
self._verbose_metrics = verbose_metrics
self.dropout = torch.nn.Dropout(dropout)
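        # Project each token's contextual embedding down to the tag space;
        # TimeDistributed applies the same Linear layer at every timestep.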
self.tag_projection_layer = TimeDistributed(
Linear(self.text_field_embedder.get_output_dim(), self.num_tags)
)
        # If constrain_crf_decoding and calculate_span_f1 are not
        # provided (i.e., they are None), set them to True if
        # label_encoding is provided and to False if it is not.
if constrain_crf_decoding is None:
constrain_crf_decoding = label_encoding is not None
if calculate_span_f1 is None:
calculate_span_f1 = label_encoding is not None
self.label_encoding = label_encoding
if constrain_crf_decoding:
if not label_encoding:
raise ConfigurationError("constrain_crf_decoding is True, but "
"no label_encoding was specified.")
labels = self.vocab.get_index_to_token_vocabulary(label_namespace)
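            # allowed_transitions enumerates the (from_tag, to_tag) index pairs that
            # are valid under the chosen encoding, e.g. BIO forbids jumping from "O"
            # straight to an "I-" tag.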
constraints = allowed_transitions(label_encoding, labels)
else:
constraints = None
self.include_start_end_transitions = include_start_end_transitions
self.crf = ConditionalRandomField(
self.num_tags, constraints,
include_start_end_transitions=include_start_end_transitions
)
self.metrics = {
"accuracy": CategoricalAccuracy(),
"accuracy3": CategoricalAccuracy(top_k=3)
}
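        # Also track a per-label, token-level F1 for every tag in the vocabulary.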
for index, label in self.vocab.get_index_to_token_vocabulary(label_namespace).items():
self.metrics['F1_' + label] = F1Measure(positive_label=index)
self.calculate_span_f1 = calculate_span_f1
if calculate_span_f1:
if not label_encoding:
raise ConfigurationError("calculate_span_f1 is True, but "
"no label_encoding was specified.")
self._f1_metric = SpanBasedF1Measure(vocab,
tag_namespace=label_namespace,
label_encoding=label_encoding)
        initializer(self)

@overrides
def forward(self, # type: ignore
tokens: Dict[str, torch.LongTensor],
tags: torch.LongTensor = None,
metadata: List[Dict[str, Any]] = None,
# pylint: disable=unused-argument
**kwargs) -> Dict[str, torch.Tensor]:
# pylint: disable=arguments-differ
"""
Parameters
----------
tokens : ``Dict[str, torch.LongTensor]``, required
The output of ``TextField.as_array()``, which should typically be passed directly to a
``TextFieldEmbedder``. This output is a dictionary mapping keys to ``TokenIndexer``
tensors. At its most basic, using a ``SingleIdTokenIndexer`` this is: ``{"tokens":
Tensor(batch_size, num_tokens)}``. This dictionary will have the same keys as were used
for the ``TokenIndexers`` when you created the ``TextField`` representing your
sequence. The dictionary is designed to be passed directly to a ``TextFieldEmbedder``,
which knows how to combine different word representations into a single vector per
token in your input.
tags : ``torch.LongTensor``, optional (default = ``None``)
A torch tensor representing the sequence of integer gold class labels of shape
``(batch_size, num_tokens)``.
        metadata : ``List[Dict[str, Any]]``, optional (default = ``None``)
            Metadata containing the original words of the sentence to be tagged, under a 'words' key.
Returns
-------
An output dictionary consisting of:
logits : ``torch.FloatTensor``
The logits that are the output of the ``tag_projection_layer``
mask : ``torch.LongTensor``
The text field mask for the input tokens
tags : ``List[List[int]]``
The predicted tags using the Viterbi algorithm.
loss : ``torch.FloatTensor``, optional
A scalar loss to be optimised. Only computed if gold label ``tags`` are provided.
"""
embedded_text_input = self.text_field_embedder(tokens)
mask = util.get_text_field_mask(tokens)
embedded_text_input = self.dropout(embedded_text_input)
logits = self.tag_projection_layer(embedded_text_input)
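        # viterbi_tags runs (optionally constrained) Viterbi decoding and returns
        # one (tag_sequence, viterbi_score) pair per instance in the batch.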
best_paths = self.crf.viterbi_tags(logits, mask)
# Just get the tags and ignore the score.
predicted_tags = [x for x, y in best_paths]
output = {"logits": logits, "mask": mask, "tags": predicted_tags}
if tags is not None:
# Add negative log-likelihood as loss
log_likelihood = self.crf(logits, tags, mask)
output["loss"] = -log_likelihood
# Represent viterbi tags as "class probabilities" that we can
# feed into the metrics
class_probabilities = logits * 0.
for i, instance_tags in enumerate(predicted_tags):
for j, tag_id in enumerate(instance_tags):
class_probabilities[i, j, tag_id] = 1
for metric in self.metrics.values():
metric(class_probabilities, tags, mask.float())
if self.calculate_span_f1:
self._f1_metric(class_probabilities, tags, mask.float())
if metadata is not None:
output["words"] = [x["words"] for x in metadata]
        return output

@overrides
def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""
Converts the tag ids to the actual tags.
``output_dict["tags"]`` is a list of lists of tag_ids,
so we use an ugly nested list comprehension.
"""
output_dict["tags"] = [
[self.vocab.get_token_from_index(tag, namespace=self.label_namespace)
for tag in instance_tags]
for instance_tags in output_dict["tags"]
]
        return output_dict

@overrides
def get_metrics(self, reset: bool = False) -> Dict[str, float]:
metrics_to_return = {}
total_f1, total_classes = 0, 0
for metric_name, metric_obj in self.metrics.items():
if metric_name.startswith('accuracy'):
metrics_to_return[metric_name] = metric_obj.get_metric(reset)
elif metric_name.startswith('F1_'):
p, r, f1 = metric_obj.get_metric(reset)
metrics_to_return[metric_name] = f1
total_f1 += f1
total_classes += 1
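        # Report the macro-average of the per-label token-level F1 scores.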
metrics_to_return['avg_f1'] = total_f1 / total_classes
if self.calculate_span_f1:
f1_dict = self._f1_metric.get_metric(reset=reset)
if self._verbose_metrics:
metrics_to_return.update(f1_dict)
else:
                metrics_to_return.update({x: y for x, y in f1_dict.items() if "overall" in x})
return metrics_to_return
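

# --- A minimal usage sketch, not part of the original model code. ---
# Assumptions: the AllenNLP 0.x API used above, a toy hand-built vocabulary,
# and a plain trainable ``Embedding`` standing in for a real BERT embedder.
if __name__ == "__main__":
    from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
    from allennlp.modules.token_embedders import Embedding

    vocab = Vocabulary()
    for token in ["the", "cat", "sat"]:
        vocab.add_token_to_namespace(token, namespace="tokens")
    for tag in ["O", "B-ANIMAL", "I-ANIMAL"]:
        vocab.add_token_to_namespace(tag, namespace="labels")

    embedder = BasicTextFieldEmbedder(
        {"tokens": Embedding(num_embeddings=vocab.get_vocab_size("tokens"),
                             embedding_dim=16)}
    )
    model = BertCrfTagger(vocab, embedder, label_encoding="BIO")
    model.eval()  # disable dropout for a deterministic forward pass

    token_ids = torch.tensor([[vocab.get_token_index(w, namespace="tokens")
                               for w in ["the", "cat", "sat"]]])
    output = model.decode(model({"tokens": token_ids}))
    print(output["tags"])  # e.g. [['O', 'O', 'O']]; untrained, so arbitrary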