-
Notifications
You must be signed in to change notification settings - Fork 2.2k
/
dynamic_iterator.py
210 lines (182 loc) · 8.14 KB
/
dynamic_iterator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
"""Module that contain iterator used for dynamic data."""
from itertools import cycle
from torchtext.data import batch as torchtext_batch
from onmt.inputters import str2sortkey, max_tok_len, OrderedIterator
from onmt.inputters.corpus import get_corpora, build_corpora_iters,\
DatasetAdapter
from onmt.transforms import make_transforms
from onmt.utils.logging import logger
class MixingStrategy(object):
    """Mixing strategy that should be used in Data Iterator.

    Args:
        iterables (dict[str, iterable]): iterables to mix, keyed by name;
        weights (dict[str, int]): weight per iterable, same keys as iterables.

    Raises:
        ValueError: if `iterables` and `weights` do not share the same keys.
    """
    def __init__(self, iterables, weights):
        """Initilize neccessary attr."""
        self._valid_iterable(iterables, weights)
        self.iterables = iterables
        self.weights = weights
    def _valid_iterable(self, iterables, weights):
        """Check that `iterables` and `weights` have identical key sets."""
        iter_keys = iterables.keys()
        weight_keys = weights.keys()
        if iter_keys != weight_keys:
            # Fix: the second placeholder previously repeated `iterables`,
            # so the mismatching `weights` keys never appeared in the error.
            raise ValueError(
                f"keys in {iterables} & {weights} should be equal.")
    def __iter__(self):
        raise NotImplementedError
class SequentialMixer(MixingStrategy):
    """Generate data sequentially from `iterables` which is exhaustible."""
    def _iter_datasets(self):
        """Yield every dataset name, each repeated `weight` times."""
        for name, weight in self.weights.items():
            yield from [name] * weight
    def __iter__(self):
        # Drain each dataset completely before moving to the next one.
        for name in self._iter_datasets():
            yield from self.iterables[name]
class WeightedMixer(MixingStrategy):
    """A mixing strategy that mix data weightedly and iterate infinitely."""
    def __init__(self, iterables, weights):
        super().__init__(iterables, weights)
        self._iterators = {}  # ds_name -> live iterator over its iterable
        self._counts = {}  # ds_name -> number of times (re)loaded
        for ds_name in self.iterables.keys():
            self._reset_iter(ds_name)
    def _logging(self):
        """Report corpora loading statistics."""
        msgs = []
        for ds_name, ds_count in self._counts.items():
            msgs.append(f"\t\t\t* {ds_name}: {ds_count}")
        logger.info("Weighted corpora loaded so far:\n"+"\n".join(msgs))
    def _reset_iter(self, ds_name):
        """(Re)start iteration over `ds_name` and bump its load count."""
        self._iterators[ds_name] = iter(self.iterables[ds_name])
        self._counts[ds_name] = self._counts.get(ds_name, 0) + 1
        self._logging()
    def _iter_datasets(self):
        """Yield every dataset name, each repeated `weight` times."""
        for ds_name, ds_weight in self.weights.items():
            for _ in range(ds_weight):
                yield ds_name
    def __iter__(self):
        for ds_name in cycle(self._iter_datasets()):
            iterator = self._iterators[ds_name]
            try:
                item = next(iterator)
            except StopIteration:
                # Dataset exhausted: restart it and retry once.
                self._reset_iter(ds_name)
                iterator = self._iterators[ds_name]
                item = next(iterator)
            # Fix: the original yielded from a `finally` clause, which runs
            # even while an exception is propagating out of the `except`
            # branch (e.g. an empty corpus re-raising StopIteration) and
            # would then reference a possibly-unbound `item`. Yielding after
            # the try block preserves the normal path exactly.
            yield item
class DynamicDatasetIter(object):
    """Yield batch from (multiple) plain text corpus.

    Args:
        corpora (dict[str, ParallelCorpus]): collections of corpora to iterate;
        corpora_info (dict[str, dict]): corpora infos correspond to corpora;
        transforms (dict[str, Transform]): transforms may be used by corpora;
        fields (dict[str, Field]): fields dict for convert corpora into Tensor;
        is_train (bool): True when generate data for training;
        batch_type (str): batching type to count on, choices=[tokens, sents];
        batch_size (int): numbers of examples in a batch;
        batch_size_multiple (int): make batch size multiply of this;
        data_type (str): input data type, currently only text;
        bucket_size (int): accum this number of examples in a dynamic dataset;
        pool_factor (int): accum this number of batch before sorting;
        skip_empty_level (str): security level when encouter empty line;
        stride (int): iterate data files with this stride;
        offset (int): iterate data files with this offset.

    Attributes:
        batch_size_fn (function): functions to calculate batch_size;
        sort_key (function): functions define how to sort examples;
        dataset_adapter (DatasetAdapter): organize raw corpus to tensor adapt;
        mixer (MixingStrategy): the strategy to iterate corpora.
    """
    def __init__(self, corpora, corpora_info, transforms, fields, is_train,
                 batch_type, batch_size, batch_size_multiple, data_type="text",
                 bucket_size=2048, pool_factor=8192,
                 skip_empty_level='warning', stride=1, offset=0):
        self.corpora = corpora
        self.transforms = transforms
        self.fields = fields
        self.corpora_info = corpora_info
        self.is_train = is_train
        self.init_iterators = False
        self.batch_size = batch_size
        # Count batch size in tokens only when explicitly requested;
        # `None` falls back to counting sentences.
        if batch_type == "tokens":
            self.batch_size_fn = max_tok_len
        else:
            self.batch_size_fn = None
        self.batch_size_multiple = batch_size_multiple
        self.device = 'cpu'
        self.sort_key = str2sortkey[data_type]
        self.bucket_size = bucket_size
        self.pool_factor = pool_factor
        if stride <= 0:
            raise ValueError(f"Invalid argument for stride={stride}.")
        self.stride = stride
        self.offset = offset
        if skip_empty_level not in ('silent', 'warning', 'error'):
            raise ValueError(
                f"Invalid argument skip_empty_level={skip_empty_level}")
        self.skip_empty_level = skip_empty_level
    @classmethod
    def from_opts(cls, corpora, transforms, fields, opts, is_train,
                  stride=1, offset=0):
        """Initilize `DynamicDatasetIter` with options parsed from `opts`."""
        if is_train:
            batch_size = opts.batch_size
        else:
            batch_size = opts.valid_batch_size
        batch_size_multiple = opts.batch_size_multiple
        if batch_size_multiple is None:
            # fp16 kernels prefer multiples of 8 for tensor-core alignment
            batch_size_multiple = 8 if opts.model_dtype == "fp16" else 1
        return cls(
            corpora, opts.data, transforms, fields, is_train, opts.batch_type,
            batch_size, batch_size_multiple, data_type=opts.data_type,
            bucket_size=opts.bucket_size, pool_factor=opts.pool_factor,
            skip_empty_level=opts.skip_empty_level,
            stride=stride, offset=offset
        )
    def _init_datasets(self):
        """Build corpora iterators, the dataset adapter and the mixer."""
        ds_iterables = build_corpora_iters(
            self.corpora, self.transforms, self.corpora_info,
            skip_empty_level=self.skip_empty_level,
            stride=self.stride, offset=self.offset)
        self.dataset_adapter = DatasetAdapter(self.fields, self.is_train)
        ds_weights = {
            name: int(self.corpora_info[name]['weight'])
            for name in ds_iterables
        }
        # Training mixes corpora weightedly and forever; validation makes
        # a single sequential pass.
        mixer_cls = WeightedMixer if self.is_train else SequentialMixer
        self.mixer = mixer_cls(ds_iterables, ds_weights)
        self.init_iterators = True
    def _bucketing(self):
        """Yield buckets of `bucket_size` examples drawn from the mixer."""
        yield from torchtext_batch(
            self.mixer,
            batch_size=self.bucket_size,
            batch_size_fn=None)
    def __iter__(self):
        if not self.init_iterators:
            self._init_datasets()
        for bucket in self._bucketing():
            dataset = self.dataset_adapter(bucket)
            ordered = OrderedIterator(
                dataset,
                self.batch_size,
                pool_factor=self.pool_factor,
                batch_size_fn=self.batch_size_fn,
                batch_size_multiple=self.batch_size_multiple,
                device=self.device,
                train=self.is_train,
                sort=False,
                sort_within_batch=True,
                sort_key=self.sort_key,
                repeat=False,
            )
            yield from ordered
def build_dynamic_dataset_iter(fields, transforms_cls, opts, is_train=True,
                               stride=1, offset=0):
    """Build `DynamicDatasetIter` from fields & opts."""
    transforms = make_transforms(opts, transforms_cls, fields)
    corpora = get_corpora(opts, is_train)
    if corpora is not None:
        return DynamicDatasetIter.from_opts(
            corpora, transforms, fields, opts, is_train,
            stride=stride, offset=offset)
    # Missing corpora is acceptable only for the validation iterator.
    assert not is_train, "only valid corpus is ignorable."
    return None