This repository has been archived by the owner on Dec 16, 2022. It is now read-only.
/
span_extractor.py
68 lines (57 loc) · 2.64 KB
/
span_extractor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from typing import Optional

import torch
from overrides import overrides

from allennlp.common.registrable import Registrable
class SpanExtractor(torch.nn.Module, Registrable):
    """
    Many NLP models deal with representations of spans inside a sentence.
    SpanExtractors define methods for extracting and representing spans
    from a sentence.

    SpanExtractors take a sequence tensor of shape (batch_size, timesteps, embedding_dim)
    and indices of shape (batch_size, num_spans, 2) and return a tensor of
    shape (batch_size, num_spans, ...), forming some representation of the
    spans.

    This is an abstract base class: all three methods raise
    `NotImplementedError` and must be overridden by concrete subclasses.
    """

    @overrides
    def forward(
        self,
        sequence_tensor: torch.FloatTensor,
        span_indices: torch.LongTensor,
        # PEP 484: a default of None requires an explicit Optional annotation.
        sequence_mask: Optional[torch.BoolTensor] = None,
        span_indices_mask: Optional[torch.BoolTensor] = None,
    ) -> torch.FloatTensor:
        """
        Given a sequence tensor, extract spans and return representations of
        them. Span representation can be computed in many different ways,
        such as concatenation of the start and end spans, attention over the
        vectors contained inside the span, etc.

        # Parameters

        sequence_tensor : `torch.FloatTensor`, required.
            A tensor of shape (batch_size, sequence_length, embedding_size)
            representing an embedded sequence of words.
        span_indices : `torch.LongTensor`, required.
            A tensor of shape `(batch_size, num_spans, 2)`, where the last
            dimension represents the inclusive start and end indices of the
            span to be extracted from the `sequence_tensor`.
        sequence_mask : `torch.BoolTensor`, optional (default = `None`).
            A tensor of shape (batch_size, sequence_length) representing padded
            elements of the sequence.
        span_indices_mask : `torch.BoolTensor`, optional (default = `None`).
            A tensor of shape (batch_size, num_spans) representing the valid
            spans in the `indices` tensor. This mask is optional because
            sometimes it's easier to worry about masking after calling this
            function, rather than passing a mask directly.

        # Returns

        A tensor of shape `(batch_size, num_spans, embedded_span_size)`,
        where `embedded_span_size` depends on the way spans are represented.
        """
        raise NotImplementedError

    def get_input_dim(self) -> int:
        """
        Returns the expected final dimension of the `sequence_tensor`.
        """
        raise NotImplementedError

    def get_output_dim(self) -> int:
        """
        Returns the expected final dimension of the returned span representation.
        """
        raise NotImplementedError