This repository has been archived by the owner on Dec 16, 2022. It is now read-only.
/
span_field.py
75 lines (57 loc) · 2.68 KB
/
span_field.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
from typing import Dict
from overrides import overrides
import torch
from allennlp.data.fields.field import Field
from allennlp.data.fields.sequence_field import SequenceField
class SpanField(Field[torch.Tensor]):
    """
    A `SpanField` is a pair of inclusive, zero-indexed (start, end) indices into a
    :class:`~allennlp.data.fields.sequence_field.SequenceField`, used to represent a span of text.
    Because it's a pair of indices into a :class:`SequenceField`, we take one of those as input
    to make the span's dependence explicit and to validate that the span is well defined.

    Both indices are inclusive, so `span_start == span_end` is accepted by the constructor.
    No lower bound is checked on either index.

    # Parameters

    span_start : `int`, required.
        The index of the start of the span in the :class:`SequenceField`.
    span_end : `int`, required.
        The inclusive index of the end of the span in the :class:`SequenceField`.
    sequence_field : `SequenceField`, required.
        A field containing the sequence that this `SpanField` is a span inside.

    # Raises

    TypeError
        If `span_start` or `span_end` is not an `int`.
    ValueError
        If `span_start > span_end`, or if `span_end` is greater than
        `sequence_field.sequence_length() - 1`.
    """

    __slots__ = ["span_start", "span_end", "sequence_field"]

    def __init__(self, span_start: int, span_end: int, sequence_field: SequenceField) -> None:
        self.span_start = span_start
        self.span_end = span_end
        self.sequence_field = sequence_field

        # Type check comes first so the ordering comparisons below never see non-integers.
        if not isinstance(span_start, int) or not isinstance(span_end, int):
            raise TypeError(
                f"SpanFields must be passed integer indices. Found span indices: "
                f"({span_start}, {span_end}) with types "
                f"({type(span_start)} {type(span_end)})"
            )

        # The span is inclusive on both ends, so equal indices are valid; only a start
        # strictly after the end is rejected.
        if span_start > span_end:
            raise ValueError(
                f"span_start must be less than or equal to span_end, "
                f"but found ({span_start}, {span_end})."
            )

        # Largest index that is still inside the wrapped sequence.
        last_index = sequence_field.sequence_length() - 1
        if span_end > last_index:
            raise ValueError(
                f"span_end must be <= len(sequence_length) - 1, but found "
                f"{span_end} and {last_index} respectively."
            )

    @overrides
    def get_padding_lengths(self) -> Dict[str, int]:
        """
        Return an empty dict: this field reports no padding dimensions.
        """
        return {}

    @overrides
    def as_tensor(self, padding_lengths: Dict[str, int]) -> torch.Tensor:
        """
        Return a `torch.LongTensor` holding `[span_start, span_end]`.

        `padding_lengths` is accepted for interface compatibility and is not used.
        """
        tensor = torch.LongTensor([self.span_start, self.span_end])
        return tensor

    @overrides
    def empty_field(self) -> "SpanField":
        """
        Return a placeholder `SpanField` with indices `(-1, -1)` that wraps the
        empty version of this field's `sequence_field`.
        """
        return SpanField(-1, -1, self.sequence_field.empty_field())

    def __str__(self) -> str:
        return f"SpanField with spans: ({self.span_start}, {self.span_end})."

    def __eq__(self, other) -> bool:
        """
        Compare against a 2-tuple by `(span_start, span_end)`; for any other
        operand, defer to the base class comparison.
        """
        if isinstance(other, tuple) and len(other) == 2:
            return other == (self.span_start, self.span_end)
        return super().__eq__(other)

    def __len__(self) -> int:
        # A span is always exactly two indices: start and end.
        return 2