This repository has been archived by the owner on Dec 16, 2022. It is now read-only.
/
adjacency_field.py
164 lines (138 loc) · 6.47 KB
/
adjacency_field.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
from typing import Dict, List, Set, Tuple, Optional
import logging
import textwrap
from overrides import overrides
import torch
from allennlp.common.checks import ConfigurationError
from allennlp.data.fields.field import Field
from allennlp.data.fields.sequence_field import SequenceField
from allennlp.data.vocabulary import Vocabulary
logger = logging.getLogger(__name__)
class AdjacencyField(Field[torch.Tensor]):
"""
A `AdjacencyField` defines directed adjacency relations between elements
in a :class:`~allennlp.data.fields.sequence_field.SequenceField`.
Because it's a labeling of some other field, we take that field as input here
and use it to determine our padding and other things.
This field will get converted into an array of shape (sequence_field_length, sequence_field_length),
where the (i, j)th array element is either a binary flag indicating there is an edge from i to j,
or an integer label k, indicating there is a label from i to j of type k.
# Parameters
indices : `List[Tuple[int, int]]`
sequence_field : `SequenceField`
A field containing the sequence that this `AdjacencyField` is labeling. Most often,
this is a `TextField`, for tagging edge relations between tokens in a sentence.
labels : `List[str]`, optional, (default = `None`)
Optional labels for the edges of the adjacency matrix.
label_namespace : `str`, optional (default=`'labels'`)
The namespace to use for converting tag strings into integers. We convert tag strings to
integers for you, and this parameter tells the `Vocabulary` object which mapping from
strings to integers to use (so that "O" as a tag doesn't get the same id as "O" as a word).
padding_value : `int`, optional (default = `-1`)
The value to use as padding.
"""
__slots__ = [
"indices",
"labels",
"sequence_field",
"_label_namespace",
"_padding_value",
"_indexed_labels",
]
# It is possible that users want to use this field with a namespace which uses OOV/PAD tokens.
# This warning will be repeated for every instantiation of this class (i.e for every data
# instance), spewing a lot of warnings so this class variable is used to only log a single
# warning per namespace.
_already_warned_namespaces: Set[str] = set()
def __init__(
self,
indices: List[Tuple[int, int]],
sequence_field: SequenceField,
labels: List[str] = None,
label_namespace: str = "labels",
padding_value: int = -1,
) -> None:
self.indices = indices
self.labels = labels
self.sequence_field = sequence_field
self._label_namespace = label_namespace
self._padding_value = padding_value
self._indexed_labels: Optional[List[int]] = None
self._maybe_warn_for_namespace(label_namespace)
field_length = sequence_field.sequence_length()
if len(set(indices)) != len(indices):
raise ConfigurationError(f"Indices must be unique, but found {indices}")
if not all(
0 <= index[1] < field_length and 0 <= index[0] < field_length for index in indices
):
raise ConfigurationError(
f"Label indices and sequence length "
f"are incompatible: {indices} and {field_length}"
)
if labels is not None and len(indices) != len(labels):
raise ConfigurationError(
f"Labelled indices were passed, but their lengths do not match: "
f" {labels}, {indices}"
)
def _maybe_warn_for_namespace(self, label_namespace: str) -> None:
if not (self._label_namespace.endswith("labels") or self._label_namespace.endswith("tags")):
if label_namespace not in self._already_warned_namespaces:
logger.warning(
"Your label namespace was '%s'. We recommend you use a namespace "
"ending with 'labels' or 'tags', so we don't add UNK and PAD tokens by "
"default to your vocabulary. See documentation for "
"`non_padded_namespaces` parameter in Vocabulary.",
self._label_namespace,
)
self._already_warned_namespaces.add(label_namespace)
@overrides
def count_vocab_items(self, counter: Dict[str, Dict[str, int]]):
if self._indexed_labels is None and self.labels is not None:
for label in self.labels:
counter[self._label_namespace][label] += 1 # type: ignore
@overrides
def index(self, vocab: Vocabulary):
if self.labels is not None:
self._indexed_labels = [
vocab.get_token_index(label, self._label_namespace) for label in self.labels
]
@overrides
def get_padding_lengths(self) -> Dict[str, int]:
return {"num_tokens": self.sequence_field.sequence_length()}
@overrides
def as_tensor(self, padding_lengths: Dict[str, int]) -> torch.Tensor:
desired_num_tokens = padding_lengths["num_tokens"]
tensor = torch.ones(desired_num_tokens, desired_num_tokens) * self._padding_value
labels = self._indexed_labels or [1 for _ in range(len(self.indices))]
for index, label in zip(self.indices, labels):
tensor[index] = label
return tensor
@overrides
def empty_field(self) -> "AdjacencyField":
# The empty_list here is needed for mypy
empty_list: List[Tuple[int, int]] = []
adjacency_field = AdjacencyField(
empty_list, self.sequence_field.empty_field(), padding_value=self._padding_value
)
return adjacency_field
def __str__(self) -> str:
length = self.sequence_field.sequence_length()
formatted_labels = "".join(
"\t\t" + labels + "\n" for labels in textwrap.wrap(repr(self.labels), 100)
)
formatted_indices = "".join(
"\t\t" + index + "\n" for index in textwrap.wrap(repr(self.indices), 100)
)
return (
f"AdjacencyField of length {length}\n"
f"\t\twith indices:\n {formatted_indices}\n"
f"\t\tand labels:\n {formatted_labels} \t\tin namespace: '{self._label_namespace}'."
)
def __len__(self):
return len(self.sequence_field)
@overrides
def human_readable_repr(self):
ret = {"indices": self.indices}
if self.labels is not None:
ret["labels"] = self.labels
return ret