forked from Fsoft-AIC/RoPADet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
add_target_dataset.py
79 lines (67 loc) · 2.66 KB
/
add_target_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from . import BaseWrapperDataset, data_utils
from fairseq.data.text_compressor import TextCompressor, TextCompressionLevel
class AddTargetDataset(BaseWrapperDataset):
    """Wrap a dataset so that items carry a ``label`` and batches a ``target``.

    Each item fetched from the wrapped dataset gets a ``"label"`` entry taken
    from ``labels``; :meth:`collater` then gathers those labels into the
    ``"target"`` (and related) fields of the collated batch.

    Args:
        dataset: The wrapped dataset. This class calls ``dataset[index]``,
            ``dataset.size(index)`` and ``dataset.collater(samples)`` on it.
        labels: Indexable collection of (possibly compressed) labels;
            ``labels[index]`` is passed to ``TextCompressor.decompress``.
        pad: Padding index forwarded to ``data_utils.collate_tokens``.
        eos: Fill value for the end-of-sequence element added when
            ``add_to_input`` is set.
        batch_targets: If truthy, targets are collated with
            ``data_utils.collate_tokens`` and ``"target_lengths"`` is
            recorded; otherwise targets stay a Python list.
        process_label: Optional callable applied to the decompressed label
            in :meth:`__getitem__`.
        label_len_fn: Callable returning the length of a raw (unprocessed)
            label. Required by :meth:`size` and therefore by
            :meth:`filter_indices_by_size`.
        add_to_input: If truthy, an eos element is appended to every target
            and an eos-prefixed copy is stored under
            ``collated["net_input"]["prev_output_tokens"]``.
        text_compression_level: Level forwarded to ``TextCompressor``.
    """

    def __init__(
        self,
        dataset,
        labels,
        pad,
        eos,
        batch_targets,
        process_label=None,
        label_len_fn=None,
        add_to_input=False,
        text_compression_level=TextCompressionLevel.none,
    ):
        super().__init__(dataset)
        self.labels = labels
        self.batch_targets = batch_targets
        self.pad = pad
        self.eos = eos
        self.process_label = process_label
        self.label_len_fn = label_len_fn
        self.add_to_input = add_to_input
        self.text_compressor = TextCompressor(level=text_compression_level)

    def get_label(self, index, process_fn=None):
        """Return the decompressed label at ``index``.

        Args:
            index: Position in ``self.labels``.
            process_fn: Optional callable; when given, its result on the
                decompressed label is returned instead of the label itself.
        """
        lbl = self.labels[index]
        lbl = self.text_compressor.decompress(lbl)
        return lbl if process_fn is None else process_fn(lbl)

    def __getitem__(self, index):
        """Return the wrapped dataset's item with a ``"label"`` entry added.

        The item returned by the wrapped dataset is modified in place.
        """
        item = self.dataset[index]
        item["label"] = self.get_label(index, process_fn=self.process_label)
        return item

    def size(self, index):
        """Return ``(wrapped_size, label_length)`` for the example at ``index``.

        The label length is ``label_len_fn`` applied to the raw label, i.e.
        the label *before* ``process_label`` is applied.

        Raises:
            TypeError: If no ``label_len_fn`` was supplied at construction.
        """
        if self.label_len_fn is None:
            # Same exception type as calling ``None`` would raise, but with a
            # message that points at the actual misconfiguration.
            raise TypeError(
                "AddTargetDataset.size() requires label_len_fn, "
                "but none was provided"
            )
        sz = self.dataset.size(index)
        own_sz = self.label_len_fn(self.get_label(index))
        return sz, own_sz

    def collater(self, samples):
        """Collate ``samples`` and attach their targets to the batch.

        The wrapped dataset's collater runs first; an empty result is
        returned unchanged. Otherwise the following keys are written:

        * ``"target"``: a padded tensor when ``batch_targets`` is truthy,
          else a list with one label per retained sample.
        * ``"target_lengths"``: only when ``batch_targets`` is truthy;
          lengths of the labels before any eos is added.
        * ``"ntokens"``: total label length, plus one per sample when
          ``add_to_input`` is set.
        * ``collated["net_input"]["prev_output_tokens"]``: only when
          ``add_to_input`` is set; the targets with eos prepended.

        Args:
            samples: List of items as produced by :meth:`__getitem__`; each
                must provide ``"id"`` and ``"label"``.
        """
        collated = self.dataset.collater(samples)
        if len(collated) == 0:
            return collated

        # The wrapped collater may drop samples; keep only the labels whose
        # ids made it into the batch.
        indices = set(collated["id"].tolist())
        target = [s["label"] for s in samples if s["id"] in indices]

        if self.batch_targets:
            collated["target_lengths"] = torch.LongTensor([len(t) for t in target])
            target = data_utils.collate_tokens(target, pad_idx=self.pad, left_pad=False)
            collated["ntokens"] = collated["target_lengths"].sum().item()
        else:
            collated["ntokens"] = sum(len(t) for t in target)

        collated["target"] = target

        if self.add_to_input:
            if self.batch_targets:
                # NOTE(review): eos is added as an extra column on the already
                # collated tensor, so (assuming collate_tokens right-pads with
                # left_pad=False) rows shorter than the longest target get eos
                # *after* their padding, and "target_lengths" does not count
                # it. Kept as-is for backward compatibility.
                eos = target.new_full((target.size(0), 1), self.eos)
                collated["target"] = torch.cat([target, eos], dim=-1).long()
                collated["net_input"]["prev_output_tokens"] = torch.cat(
                    [eos, target], dim=-1
                ).long()
                collated["ntokens"] += target.size(0)
            else:
                # Un-batched targets are a plain list, which has no
                # ``new_full``/``size``; add eos to each label individually.
                # Assumes every label is a 1-D tensor -- TODO confirm.
                with_eos = []
                prev_output_tokens = []
                for t in target:
                    eos = t.new_full((1,), self.eos)
                    with_eos.append(torch.cat([t, eos], dim=-1).long())
                    prev_output_tokens.append(torch.cat([eos, t], dim=-1).long())
                collated["target"] = with_eos
                collated["net_input"]["prev_output_tokens"] = prev_output_tokens
                # One extra (eos) token per sample, as on the batched path.
                collated["ntokens"] += len(target)
        return collated

    def filter_indices_by_size(self, indices, max_sizes):
        """Filter ``indices`` by size using :meth:`size`.

        Delegates to ``data_utils._filter_by_size_dynamic`` and returns its
        ``(indices, ignored)`` result unchanged.
        """
        indices, ignored = data_utils._filter_by_size_dynamic(
            indices, self.size, max_sizes
        )
        return indices, ignored