Fix span masking (#905)

* fix a few things with the parser

* pylint

* fix sphinx

* update variable name, try new thing for sphinx
DeNeutoy committed Feb 21, 2018
1 parent 6f4de85 commit bca992ef8d73da88b0a81707f79bf2759a8158dc
Showing with 32 additions and 13 deletions.
  1. +31 −12 allennlp/models/constituency_parser.py
  2. +1 −1 tests/models/constituency_parser_test.py
@@ -142,22 +142,32 @@ def forward(self, # type: ignore
Returns
-------
An output dictionary consisting of:
logits : ``torch.FloatTensor``
A tensor of shape ``(batch_size, num_spans, span_label_vocab_size)``
representing unnormalised log probabilities of the label classes for each span.
class_probabilities : ``torch.FloatTensor``
A tensor of shape ``(batch_size, num_spans, span_label_vocab_size)``
representing a distribution over the label classes per span.
spans : ``torch.LongTensor``
The original spans tensor.
tokens : ``torch.LongTensor``
The token ids from the ``TextField``. Has shape (batch_size, num_tokens).
sentence_lengths : ``torch.LongTensor``, required.
A tensor of shape (batch_size), representing the lengths of the non-padded
elements of ``sentences``.
num_spans : ``torch.LongTensor``, required.
A tensor of shape (batch_size), representing the lengths of non-padded spans
in ``enumerated_spans``.
loss : ``torch.FloatTensor``, optional
A scalar loss to be optimised.
"""

embedded_text_input = self.text_field_embedder(tokens)
mask = get_text_field_mask(tokens)
sentence_lengths = get_lengths_from_binary_sequence_mask(mask)
# Looking at the span start index is enough to know if
# this is padding or not. Shape: (batch_size, num_spans)
span_mask = (spans[:, :, 0] >= 0).squeeze(-1).long()

sentence_lengths = get_lengths_from_binary_sequence_mask(mask)
num_spans = get_lengths_from_binary_sequence_mask(span_mask)

encoded_text = self.encoder(embedded_text_input, mask)
span_representations = self.span_extractor(encoded_text, spans, mask, span_mask)
if self.feedforward_layer is not None:
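
For readers unfamiliar with the padding convention involved: enumerated spans are padded with negative indices (conventionally (-1, -1)), so checking the span start index, as the hunk above does, is enough to build the span mask, and summing that mask per sentence gives num_spans. A short standalone sketch of the idea, not part of this commit, in plain PyTorch:

import torch

# Two sentences' enumerated spans, padded with (-1, -1) to a common length.
spans = torch.tensor([[[0, 1], [1, 2], [0, 2]],        # 3 real spans
                      [[0, 0], [-1, -1], [-1, -1]]])   # 1 real span, 2 padding

# A non-negative start index identifies a real (non-padded) span.
span_mask = (spans[:, :, 0] >= 0).long()   # shape: (batch_size, num_spans)

# Per-sentence count of real spans; get_lengths_from_binary_sequence_mask
# computes the same row sum for a binary mask.
num_spans = span_mask.sum(-1)

print(span_mask)   # tensor([[1, 1, 1], [1, 0, 0]])
print(num_spans)   # tensor([3, 1])
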
@@ -170,7 +180,8 @@ def forward(self, # type: ignore
"spans": spans,
# TODO(Mark): This relies on having tokens represented with a SingleIdTokenIndexer...
"tokens": tokens["tokens"],
- "sentence_lengths": sentence_lengths
+ "sentence_lengths": sentence_lengths,
+ "num_spans": num_spans
}
if span_labels is not None:
loss = sequence_cross_entropy_with_logits(logits, span_labels, span_mask)
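
The third argument here is what keeps padded spans out of the training loss. As an illustration only (AllenNLP's sequence_cross_entropy_with_logits handles the weighting and per-batch averaging itself), a minimal masked cross-entropy over span labels looks roughly like this:

import torch
import torch.nn.functional as F

def masked_span_loss(logits, labels, mask):
    # logits: (batch, num_spans, num_classes); labels, mask: (batch, num_spans)
    log_probs = F.log_softmax(logits, dim=-1)
    # Negative log-likelihood of the gold label for every span.
    nll = -log_probs.gather(-1, labels.unsqueeze(-1)).squeeze(-1)
    # Zero out padded spans so they contribute nothing to the loss.
    nll = nll * mask.float()
    return nll.sum() / mask.float().sum().clamp(min=1.0)

batch, num_spans, num_classes = 2, 3, 5
logits = torch.randn(batch, num_spans, num_classes)
labels = torch.randint(0, num_classes, (batch, num_spans))
mask = torch.tensor([[1, 1, 1], [1, 0, 0]])
print(masked_span_loss(logits, labels, mask))
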
@@ -184,7 +195,8 @@ def forward(self, # type: ignore
predicted_trees = self.construct_trees(class_probabilities.cpu().data,
spans.cpu().data,
tokens["tokens"].cpu().data,
- sentence_lengths.cpu().data)
+ sentence_lengths.cpu().data,
+ num_spans.data)
self._evalb_score(predicted_trees, gold_tree)

return output_dict
@@ -195,22 +207,25 @@ def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor
Constructs an NLTK ``Tree`` given the scored spans. We also switch to exclusive
span ends when constructing the tree representation, because it makes indexing
into lists cleaner for ranges of text, rather than individual indices.
"""
all_predictions = output_dict['class_probabilities'].cpu().data
all_spans = output_dict["spans"].cpu().data

all_sentences = output_dict["tokens"].cpu().data
sentence_lengths = output_dict["sentence_lengths"].data
- trees = self.construct_trees(all_predictions, all_spans, all_sentences, sentence_lengths)
+ num_spans = output_dict["num_spans"].data
+ trees = self.construct_trees(all_predictions, all_spans, all_sentences, sentence_lengths, num_spans)

output_dict["trees"] = trees
return output_dict

def construct_trees(self,
predictions: torch.FloatTensor,
- enumerated_spans: torch.LongTensor,
+ all_spans: torch.LongTensor,
sentences: torch.LongTensor,
- sentence_lengths: torch.LongTensor) -> List[Tree]:
+ sentence_lengths: torch.LongTensor,
+ num_spans: torch.LongTensor) -> List[Tree]:
"""
Construct ``nltk.Tree``'s for each batch element by greedily nesting spans.
The trees use exclusive end indices, which contrasts with how spans are
@@ -221,7 +236,7 @@ def construct_trees(self,
predictions : ``torch.FloatTensor``, required.
A tensor of shape ``(batch_size, num_spans, span_label_vocab_size)``
representing a distribution over the label classes per span.
- enumerated_spans : ``torch.LongTensor``, required.
+ all_spans : ``torch.LongTensor``, required.
A tensor of shape (batch_size, num_spans, 2), representing the span
indices we scored.
sentences : ``torch.LongTensor``, required.
@@ -230,13 +245,16 @@ def construct_trees(self,
sentence_lengths : ``torch.LongTensor``, required.
A tensor of shape (batch_size), representing the lengths of the non-padded
elements of ``sentences``.
num_spans : ``torch.LongTensor``, required.
A tensor of shape (batch_size), representing the lengths of non-padded spans
in ``enumerated_spans``.
Returns
-------
A ``List[Tree]`` containing the decoded trees for each element in the batch.
"""
# Switch to using exclusive end spans.
- exclusive_end_spans = enumerated_spans.clone()
+ exclusive_end_spans = all_spans.clone()
exclusive_end_spans[:, :, -1] += 1
no_label_id = self.vocab.get_token_index("NO-LABEL", "labels")

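A quick illustration of why exclusive ends are convenient (toy data, not from the parser): with an exclusive end index, the tokens covered by a span are exactly the Python slice tokens[start:end].

tokens = ["The", "cat", "sat", "on", "the", "mat"]

# Inclusive span over "the mat": (start, end) = (4, 5).
start, inclusive_end = 4, 5

# Adding one to the end, as exclusive_end_spans[:, :, -1] += 1 does above,
# makes ordinary slicing line up with the span.
exclusive_end = inclusive_end + 1
assert tokens[start:exclusive_end] == ["the", "mat"]
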
@@ -248,7 +266,8 @@ def construct_trees(self,
index in sentence_ids[:sentence_lengths[batch_index]]]

selected_spans = []
- for prediction, span in zip(scored_spans, spans):
+ for prediction, span in zip(scored_spans[:num_spans[batch_index]],
+                             spans[:num_spans[batch_index]]):
start, end = span
no_label_prob = prediction[no_label_id]
label_prob, label_index = torch.max(prediction, -1)
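
This is the core of the fix in decoding: only the first num_spans[batch_index] entries of scored_spans and spans are real for this sentence, so padded (-1, -1) spans no longer reach tree construction. A standalone sketch of the truncation with made-up data:

import torch

num_spans = torch.tensor([2])                     # sentence 0 has 2 real spans
spans = torch.tensor([[0, 1], [0, 2], [-1, -1]])  # last row is padding
scored_spans = torch.tensor([[0.1, 0.9],          # per-span label distributions
                             [0.7, 0.3],
                             [0.5, 0.5]])

batch_index = 0
kept = list(zip(scored_spans[:num_spans[batch_index]],
                spans[:num_spans[batch_index]]))

# Only the two real spans are considered as potential tree constituents.
assert len(kept) == 2
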
@@ -30,7 +30,7 @@ def test_decode_runs(self):
output_dict = self.model(**training_tensors)
decode_output_dict = self.model.decode(output_dict)
assert set(decode_output_dict.keys()) == {'spans', 'class_probabilities', 'trees',
- 'tokens', 'sentence_lengths', 'loss'}
+ 'tokens', 'sentence_lengths', 'num_spans', 'loss'}
metrics = self.model.get_metrics(reset=True)
metric_keys = set(metrics.keys())
assert "evalb_precision" in metric_keys
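
A hypothetical extra assertion (not part of this commit) that would also exercise the new key: the reported num_spans can never exceed the padded span dimension of the spans tensor.

# Hypothetical follow-up check, not in the commit:
assert (decode_output_dict["num_spans"] <= decode_output_dict["spans"].size(1)).all()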
