forked from Fsoft-AIC/RoPADet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
mask_tokens_dataset.py
220 lines (193 loc) · 8.57 KB
/
mask_tokens_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from functools import lru_cache
import numpy as np
import torch
from fairseq.data import Dictionary, data_utils
from . import BaseWrapperDataset, LRUCacheDataset
class MaskTokensDataset(BaseWrapperDataset):
"""
A wrapper Dataset for masked language modeling.
Input items are masked according to the specified masking probability.
Args:
dataset: Dataset to wrap.
sizes: Sentence lengths
vocab: Dictionary with the vocabulary and special tokens.
pad_idx: Id of pad token in vocab
mask_idx: Id of mask token in vocab
return_masked_tokens: controls whether to return the non-masked tokens
(the default) or to return a tensor with the original masked token
IDs (and *pad_idx* elsewhere). The latter is useful as targets for
masked LM training.
seed: Seed for random number generator for reproducibility.
mask_prob: probability of replacing a token with *mask_idx*.
leave_unmasked_prob: probability that a masked token is unmasked.
random_token_prob: probability of replacing a masked token with a
random token from the vocabulary.
freq_weighted_replacement: sample random replacement words based on
word frequencies in the vocab.
mask_whole_words: only mask whole words. This should be a byte mask
over vocab indices, indicating whether it is the beginning of a
word. We will extend any mask to encompass the whole word.
bpe: BPE to use for whole-word masking.
mask_multiple_length : repeat each mask index multiple times. Default
value is 1.
mask_stdev : standard deviation of masks distribution in case of
multiple masking. Default value is 0.
"""
@classmethod
def apply_mask(cls, dataset: torch.utils.data.Dataset, *args, **kwargs):
"""Return the source and target datasets for masked LM training."""
dataset = LRUCacheDataset(dataset)
return (
LRUCacheDataset(cls(dataset, *args, **kwargs, return_masked_tokens=False)),
LRUCacheDataset(cls(dataset, *args, **kwargs, return_masked_tokens=True)),
)
def __init__(
self,
dataset: torch.utils.data.Dataset,
vocab: Dictionary,
pad_idx: int,
mask_idx: int,
return_masked_tokens: bool = False,
seed: int = 1,
mask_prob: float = 0.15,
leave_unmasked_prob: float = 0.1,
random_token_prob: float = 0.1,
freq_weighted_replacement: bool = False,
mask_whole_words: torch.Tensor = None,
mask_multiple_length: int = 1,
mask_stdev: float = 0.0,
):
assert 0.0 < mask_prob < 1.0
assert 0.0 <= random_token_prob <= 1.0
assert 0.0 <= leave_unmasked_prob <= 1.0
assert random_token_prob + leave_unmasked_prob <= 1.0
assert mask_multiple_length >= 1
assert mask_stdev >= 0.0
self.dataset = dataset
self.vocab = vocab
self.pad_idx = pad_idx
self.mask_idx = mask_idx
self.return_masked_tokens = return_masked_tokens
self.seed = seed
self.mask_prob = mask_prob
self.leave_unmasked_prob = leave_unmasked_prob
self.random_token_prob = random_token_prob
self.mask_whole_words = mask_whole_words
self.mask_multiple_length = mask_multiple_length
self.mask_stdev = mask_stdev
if random_token_prob > 0.0:
if freq_weighted_replacement:
weights = np.array(self.vocab.count)
else:
weights = np.ones(len(self.vocab))
weights[: self.vocab.nspecial] = 0
self.weights = weights / weights.sum()
self.epoch = 0
@property
def can_reuse_epoch_itr_across_epochs(self):
return True # only the noise changes, not item sizes
def set_epoch(self, epoch, **unused):
super().set_epoch(epoch)
self.epoch = epoch
def __getitem__(self, index: int):
return self.__getitem_cached__(self.seed, self.epoch, index)
@lru_cache(maxsize=8)
def __getitem_cached__(self, seed: int, epoch: int, index: int):
with data_utils.numpy_seed(self.seed, self.epoch, index):
item = self.dataset[index]
sz = len(item)
assert (
self.mask_idx not in item
), "Dataset contains mask_idx (={}), this is not expected!".format(
self.mask_idx,
)
if self.mask_whole_words is not None:
word_begins_mask = self.mask_whole_words.gather(0, item)
word_begins_idx = word_begins_mask.nonzero().view(-1)
sz = len(word_begins_idx)
words = np.split(word_begins_mask, word_begins_idx)[1:]
assert len(words) == sz
word_lens = list(map(len, words))
# decide elements to mask
mask = np.full(sz, False)
num_mask = int(
# add a random number for probabilistic rounding
self.mask_prob * sz / float(self.mask_multiple_length)
+ np.random.rand()
)
# multiple masking as described in the vq-wav2vec paper (https://arxiv.org/abs/1910.05453)
mask_idc = np.random.choice(sz, num_mask, replace=False)
if self.mask_stdev > 0.0:
lengths = np.random.normal(
self.mask_multiple_length, self.mask_stdev, size=num_mask
)
lengths = [max(0, int(round(x))) for x in lengths]
mask_idc = np.asarray(
[
mask_idc[j] + offset
for j in range(len(mask_idc))
for offset in range(lengths[j])
],
dtype=np.int64,
)
else:
mask_idc = np.concatenate(
[mask_idc + i for i in range(self.mask_multiple_length)]
)
mask_idc = mask_idc[mask_idc < len(mask)]
try:
mask[mask_idc] = True
except: # something wrong
print(
"Assigning mask indexes {} to mask {} failed!".format(
mask_idc, mask
)
)
raise
if self.return_masked_tokens:
# exit early if we're just returning the masked tokens
# (i.e., the targets for masked LM training)
if self.mask_whole_words is not None:
mask = np.repeat(mask, word_lens)
new_item = np.full(len(mask), self.pad_idx)
new_item[mask] = item[torch.from_numpy(mask.astype(np.uint8)) == 1]
return torch.from_numpy(new_item)
# decide unmasking and random replacement
rand_or_unmask_prob = self.random_token_prob + self.leave_unmasked_prob
if rand_or_unmask_prob > 0.0:
rand_or_unmask = mask & (np.random.rand(sz) < rand_or_unmask_prob)
if self.random_token_prob == 0.0:
unmask = rand_or_unmask
rand_mask = None
elif self.leave_unmasked_prob == 0.0:
unmask = None
rand_mask = rand_or_unmask
else:
unmask_prob = self.leave_unmasked_prob / rand_or_unmask_prob
decision = np.random.rand(sz) < unmask_prob
unmask = rand_or_unmask & decision
rand_mask = rand_or_unmask & (~decision)
else:
unmask = rand_mask = None
if unmask is not None:
mask = mask ^ unmask
if self.mask_whole_words is not None:
mask = np.repeat(mask, word_lens)
new_item = np.copy(item)
new_item[mask] = self.mask_idx
if rand_mask is not None:
num_rand = rand_mask.sum()
if num_rand > 0:
if self.mask_whole_words is not None:
rand_mask = np.repeat(rand_mask, word_lens)
num_rand = rand_mask.sum()
new_item[rand_mask] = np.random.choice(
len(self.vocab),
num_rand,
p=self.weights,
)
return torch.from_numpy(new_item)