This repository has been archived by the owner on Dec 16, 2022. It is now read-only.
/
utils.py
42 lines (35 loc) · 2.13 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import torch
from allennlp.nn.util import replace_masked_values, min_value_of_dtype
def get_best_span(span_start_logits: torch.Tensor, span_end_logits: torch.Tensor) -> torch.Tensor:
    """
    This acts the same as the static method ``BidirectionalAttentionFlow.get_best_span()``
    in ``allennlp/models/reading_comprehension/bidaf.py``. We keep it here so that users can
    directly import this function without the class.

    We call the inputs "logits" - they could either be unnormalized logits or normalized log
    probabilities. A log_softmax operation is a constant shifting of the entire logit
    vector, so taking an argmax over either one gives the same result.

    Parameters
    ----------
    span_start_logits : ``torch.Tensor``
        Shape ``(batch_size, passage_length)``: score of each position being the span start.
    span_end_logits : ``torch.Tensor``
        Same shape as ``span_start_logits``: score of each position being the span end.

    Returns
    -------
    ``torch.Tensor``
        Integer tensor of shape ``(batch_size, 2)`` holding ``[start, end]`` for the span with
        the highest ``start_logit + end_logit`` among spans with ``start <= end`` (the end
        index is inclusive, so single-token spans are allowed). Ties go to the span that comes
        first in row-major (start, then end) order, following ``torch.argmax``.

    Raises
    ------
    ValueError
        If either input is not 2-D, or the two inputs have different shapes.
    """
    if span_start_logits.dim() != 2 or span_end_logits.dim() != 2:
        raise ValueError("Input shapes must be (batch_size, passage_length)")
    if span_start_logits.size() != span_end_logits.size():
        # Without this check, mismatched shapes can broadcast and yield meaningless spans.
        raise ValueError(
            "span_start_logits and span_end_logits must have the same shape, got "
            f"{tuple(span_start_logits.size())} and {tuple(span_end_logits.size())}"
        )
    batch_size, passage_length = span_start_logits.size()
    device = span_start_logits.device

    # (batch_size, passage_length, passage_length): entry [b, i, j] scores the span i..j.
    span_log_probs = span_start_logits.unsqueeze(2) + span_end_logits.unsqueeze(1)

    # Only the upper triangle (end >= start) holds valid spans. Overwrite the rest with -inf
    # via masked_fill instead of *adding* log(0). Adding -inf turns a +inf score into NaN,
    # which argmax would then select. masked_fill also keeps the input dtype.
    positions = torch.arange(passage_length, device=device)
    span_ends_before_start = positions.unsqueeze(1) > positions.unsqueeze(0)
    valid_span_log_probs = span_log_probs.masked_fill(span_ends_before_start, float("-inf"))

    # Flatten each span matrix and take the argmax. Start and end indices are recovered from
    # the flat index with integer division and modulo by passage_length.
    # (batch_size, passage_length * passage_length)
    best_spans = valid_span_log_probs.reshape(batch_size, -1).argmax(-1)
    span_start_indices = torch.div(best_spans, passage_length, rounding_mode="trunc")
    span_end_indices = best_spans % passage_length
    return torch.stack([span_start_indices, span_end_indices], dim=-1)
def replace_masked_values_with_big_negative_number(x: torch.Tensor, mask: torch.Tensor):
    """
    Replace the masked values in ``x`` with a very large negative number, so that they
    won't affect a subsequent max operation.

    The fill value is whatever ``min_value_of_dtype`` reports for ``x.dtype``. Which entries
    count as "masked" is decided by allennlp's ``replace_masked_values``. Presumably these are
    the positions where ``mask`` is False/0; confirm against the installed allennlp version.
    """
    most_negative_value = min_value_of_dtype(x.dtype)
    return replace_masked_values(x, mask, most_negative_value)