This repository has been archived by the owner on Dec 16, 2022. It is now read-only.
/
namespace_swapping_field.py
66 lines (51 loc) · 2.16 KB
/
namespace_swapping_field.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from typing import Dict, List, Any
from overrides import overrides
import torch
from allennlp.common.util import pad_sequence_to_length
from allennlp.data.vocabulary import Vocabulary
from allennlp.data.tokenizers import Token
from allennlp.data.fields.field import Field
class NamespaceSwappingField(Field[torch.Tensor]):
    """
    A `NamespaceSwappingField` is used to map tokens in one namespace to tokens in another namespace.
    It is used by seq2seq models with a copy mechanism that copies tokens from the source
    sentence into the target sentence.

    # Parameters

    source_tokens : `List[Token]`
        The tokens from the source sentence.
    target_namespace : `str`
        The namespace that the tokens from the source sentence will be mapped to.
    """

    __slots__ = ["_source_tokens", "_target_namespace", "_mapping_array"]

    def __init__(self, source_tokens: List[Token], target_namespace: str) -> None:
        self._source_tokens = source_tokens
        self._target_namespace = target_namespace
        # Populated by `index`: one id from `target_namespace` per source token, in order.
        self._mapping_array: List[int] = []

    @overrides
    def index(self, vocab: Vocabulary):
        """
        Look up the text of every source token in `target_namespace` of `vocab` and
        store the resulting ids, in source order.

        Tokens that are absent from the target namespace get whatever id
        `Vocabulary.get_token_index` returns for them (presumably the OOV id — confirm).
        """
        self._mapping_array = [
            vocab.get_token_index(token.ensure_text(), self._target_namespace)
            for token in self._source_tokens
        ]

    @overrides
    def get_padding_lengths(self) -> Dict[str, int]:
        """Report the number of source tokens under the `"num_tokens"` key."""
        return {"num_tokens": len(self._source_tokens)}

    @overrides
    def as_tensor(self, padding_lengths: Dict[str, int]) -> torch.Tensor:
        """
        Return the mapped ids as a `torch.long` tensor, padded by `pad_sequence_to_length`
        to `padding_lengths["num_tokens"]`.

        The padding value is `pad_sequence_to_length`'s default (presumably 0 — confirm).
        `index` must have been called first; before that, the mapping is still the empty
        list set in `__init__`.
        """
        desired_length = padding_lengths["num_tokens"]
        padded_ids = pad_sequence_to_length(self._mapping_array, desired_length)
        # `torch.tensor(..., dtype=...)` is the recommended way to build a tensor from
        # Python data; the legacy `torch.LongTensor(...)` type constructor is discouraged.
        return torch.tensor(padded_ids, dtype=torch.long)

    @overrides
    def empty_field(self) -> "NamespaceSwappingField":
        """Return a field with no source tokens that targets the same namespace."""
        # `__init__` already initialises `_mapping_array` to an empty list.
        return NamespaceSwappingField([], self._target_namespace)

    def __len__(self):
        """Number of source tokens held by this field."""
        return len(self._source_tokens)

    @overrides
    def human_readable_repr(self) -> Dict[str, Any]:
        """Return the source tokens as strings together with the target namespace."""
        return {
            "source_tokens": [str(t) for t in self._source_tokens],
            "target_namespace": self._target_namespace,
        }