"""
(C) Copyright 2021 IBM Corp.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Created on June 30, 2021
"""
from typing import Dict, Hashable, List, Optional, Sequence, Union
from warnings import warn
from fuse.data.datasets.dataset_base import DatasetBase
from fuse.data.ops.ops_common import OpCollectMarker
from fuse.data.pipelines.pipeline_default import PipelineDefault
from fuse.data.datasets.caching.samples_cacher import SamplesCacher
from fuse.utils.ndict import NDict
from fuse.utils.multiprocessing.run_multiprocessed import (
run_multiprocessed,
get_from_global_storage,
)
from fuse.data import (
    get_sample_id,
    create_initial_sample,
    get_specific_sample_from_potentially_morphed,
)
import copy
from collections import OrderedDict
import numpy as np


class DatasetDefault(DatasetBase):
def __init__(
self,
sample_ids: Union[int, Sequence[Hashable], None],
static_pipeline: Optional[PipelineDefault] = None,
dynamic_pipeline: Optional[PipelineDefault] = None,
cacher: Optional[SamplesCacher] = None,
allow_uncached_sample_morphing: bool = False,
):
"""
        :param sample_ids: a list of the sample_ids included in the dataset. Alternatively:
            - An integer describing only the size of the dataset. This is useful for massive datasets
              (for example, 100M samples). In this case several features are not supported, mainly:
              cacher, allow_uncached_sample_morphing and get_all_sample_ids.
            - None. In this case the dataset does not deal with sample ids at all. It is the user's
              responsibility to handle iteration w.r.t. the length of the dataset, as well as the index
              passed to __getitem__. This is useful for massive datasets whose sample ids are not
              running integers from 0 to a given length.
        :param static_pipeline: the static pipeline; its output will be automatically cached.
        :param dynamic_pipeline: the dynamic pipeline; applied sequentially after the static_pipeline, but not
            automatically cached. Changing it will NOT trigger recaching of the static_pipeline part.
        :param cacher: optional SamplesCacher instance used for caching samples to speed up their loading
        :param allow_uncached_sample_morphing: when enabled, allows an Op to return None (drop a sample) or to
            return multiple samples (as a list)
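        Example (a minimal sketch; the empty pipelines below stand in for real lists of (op, kwargs) tuples):
            dataset = DatasetDefault(
                sample_ids=["case_1", "case_2"],
                static_pipeline=PipelineDefault("static", ops_and_kwargs=[]),
                dynamic_pipeline=PipelineDefault("dynamic", ops_and_kwargs=[]),
            )
            dataset.create()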
"""
super().__init__()
# store arguments
self._cacher = cacher
if isinstance(sample_ids, (int, np.integer)):
if allow_uncached_sample_morphing:
raise Exception(
"allow_uncached_sample_morphing is not allowed when providing sample_ids=an integer value"
)
if cacher is not None:
raise Exception("providing a cacher is not allowed when providing sample_ids=an integer value")
self._sample_ids_mode = "running_int"
elif sample_ids is None:
self._sample_ids_mode = "external"
else:
self._sample_ids_mode = "explicit"
# self._orig_sample_ids = sample_ids
self._allow_uncached_sample_morphing = allow_uncached_sample_morphing
# verify unique names for dynamic pipelines
if dynamic_pipeline is not None and static_pipeline is not None:
if static_pipeline.get_name() == dynamic_pipeline.get_name():
raise Exception(
f"Detected identical name for static pipeline and dynamic pipeline ({static_pipeline.get_name(static_pipeline.get_name())}).\nThis is not allowed, please initiate the pipelines with different names."
)
        if static_pipeline is None:
            static_pipeline = PipelineDefault("dummy_static_pipeline", ops_and_kwargs=[])
        if dynamic_pipeline is None:
            dynamic_pipeline = PipelineDefault("dummy_dynamic_pipeline", ops_and_kwargs=[])
        assert isinstance(
            static_pipeline, PipelineDefault
        ), f"static_pipeline may be None or a PipelineDefault instance. Instead got {type(static_pipeline)}"
        assert isinstance(
            dynamic_pipeline, PipelineDefault
        ), f"dynamic_pipeline may be None or a PipelineDefault instance. Instead got {type(dynamic_pipeline)}"
if self._allow_uncached_sample_morphing:
warn(
"allow_uncached_sample_morphing is enabled! It is a significantly slower mode and should be used ONLY for debugging"
)
self._static_pipeline = static_pipeline
self._dynamic_pipeline = dynamic_pipeline
self._orig_sample_ids = copy.deepcopy(sample_ids)
self._created = False

    def create(self, num_workers: int = 0, mp_context: Optional[str] = None) -> None:
"""
        Create the dataset, including caching.
        :param num_workers: number of workers. Used only when caching is disabled and allow_uncached_sample_morphing is enabled.
            Set num_workers=0 to disable multiprocessing (more convenient for debugging).
            The number of workers used for caching is set in the cacher constructor.
        :param mp_context: "fork", "spawn", "thread" or None for the multiprocessing default
:return: None
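        Example (assuming an already-constructed DatasetDefault instance):
            dataset.create()  # single process
            dataset.create(num_workers=4, mp_context="spawn")  # workers matter only for uncached sample morphing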
"""
self._output_sample_ids_info = None
if self._cacher is not None:
self._output_sample_ids_info = self._cacher.cache_samples(self._orig_sample_ids)
elif self._allow_uncached_sample_morphing:
_output_sample_ids_info_list = run_multiprocessed(
DatasetDefault._process_orig_sample_id,
[(sid, self._static_pipeline, False) for sid in self._orig_sample_ids],
workers=num_workers,
mp_context=mp_context,
desc="dataset_default.sample_morphing",
)
self._output_sample_ids_info = OrderedDict()
self._final_sid_to_orig_sid = {}
for sample_in_out_info in _output_sample_ids_info_list:
orig_sid, out_sids = sample_in_out_info[0], sample_in_out_info[1]
self._output_sample_ids_info[orig_sid] = out_sids
if out_sids is not None:
assert isinstance(out_sids, list)
for final_sid in out_sids:
self._final_sid_to_orig_sid[final_sid] = orig_sid
if self._output_sample_ids_info is not None: # sample morphing is allowed
self._final_sample_ids = []
for orig_sid, out_sids in self._output_sample_ids_info.items():
if out_sids is None:
continue
self._final_sample_ids.extend(out_sids)
else:
self._final_sample_ids = self._orig_sample_ids
        self._orig_sample_ids = None  # should not be used after create(); use self._final_sample_ids instead
self._created = True

    def get_all_sample_ids(self):
if not self._created:
raise Exception("you must first call create()")
if self._sample_ids_mode != "explicit":
raise Exception("get_all_sample_ids is not supported when constructed with non explicit sample_ids")
return copy.deepcopy(self._final_sample_ids)

    def __getitem__(self, item: Union[int, Hashable]) -> dict:
"""
Get sample, read from cache if possible
:param item: either int representing sample index or sample_id
:return: sample_dict
"""
return self.getitem(item)

    def getitem(
self,
item: Union[int, Hashable],
collect_marker_name: Optional[str] = None,
keys: Optional[Sequence[str]] = None,
) -> NDict:
"""
Get sample, read from cache if possible
:param item: either int representing sample index or sample_id
:param collect_marker_name: Optional, specify name of collect marker op to optimize the running time
:param keys: Optional, return just the specified keys or everything available if set to None
:return: sample_dict
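        Example (assuming a created dataset with explicit sample_ids; "data.input.img" is a hypothetical key):
            sample = dataset.getitem(0)         # by index
            sample = dataset.getitem("case_1")  # by sample_id
            img_only = dataset.getitem(0, keys=["data.input.img"])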
"""
if not self._created:
raise Exception("you must first call create()")
# get sample id
        if self._sample_ids_mode != "explicit":
            sample_id = item
            if self._sample_ids_mode == "running_int":
                # in this mode self._final_sample_ids holds the dataset length (an integer)
                if sample_id >= self._final_sample_ids:
                    raise IndexError
        elif not isinstance(item, (int, np.integer)):  # allow using non-int sample_ids
            sample_id = item
        else:
            sample_id = self._final_sample_ids[item]
# get collect marker info
collect_marker_info = self._get_collect_marker_info(collect_marker_name)
# read sample
        if self._cacher is not None:
            sample = self._cacher.load_sample(sample_id, collect_marker_info["static_keys_deps"])
        else:
if not self._allow_uncached_sample_morphing:
sample = create_initial_sample(sample_id)
sample = self._static_pipeline(sample)
if not isinstance(sample, dict):
raise Exception(
f'By default when caching is disabled sample morphing is not allowed, and the output of the static pipeline is expected to be a dict. Instead got {type(sample)}. You can use "allow_uncached_sample_morphing=True" to allow this, but be aware it is slow and should be used only for debugging'
)
else:
orig_sid = self._final_sid_to_orig_sid[sample_id]
sample = create_initial_sample(orig_sid)
sample = self._static_pipeline(sample)
assert sample is not None
sample = get_specific_sample_from_potentially_morphed(sample, sample_id)
sample = self._dynamic_pipeline(sample, until_op_id=collect_marker_info["op_id"])
if not isinstance(sample, dict):
raise Exception(
f"The final output of dataset static (+optional dynamic) pipelines is expected to be a dict. Instead got {type(sample)}"
)
# get just required keys
if keys is not None:
sample = sample.get_multi(keys)
return sample

    def _get_multi_multiprocess_func(self, args):
sid, kwargs = args
return self.getitem(sid, **kwargs)

    @staticmethod
def _getitem_multiprocess(item: Union[Hashable, int, np.integer]):
"""
getitem method used to optimize the running time in a multiprocess mode
"""
dataset = get_from_global_storage("dataset_default_get_multi_dataset")
kwargs = get_from_global_storage("dataset_default_get_multi_kwargs")
return dataset.getitem(item, **kwargs)

    def get_multi(
self,
items: Optional[Sequence[Union[int, Hashable]]] = None,
workers: int = 10,
verbose: int = 1,
mp_context: Optional[str] = None,
desc: str = "dataset_default.get_multi",
**kwargs,
) -> List[Dict]:
"""
        See the super class.
        :param workers: number of processes used to read the data. Set to 0 to disable multiprocessing (useful when debugging).
        :param mp_context: "fork", "spawn", "thread" or None for the multiprocessing default
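        Example (workers=0 reads in the current process - convenient for debugging):
            all_samples = dataset.get_multi(workers=0)
            some_samples = dataset.get_multi(items=[0, 2, 5], workers=0)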
"""
if items is None:
sample_ids = list(range(len(self)))
else:
sample_ids = items
for_global_storage = {"dataset_default_get_multi_dataset": self, "dataset_default_get_multi_kwargs": kwargs}
list_sample_dict = run_multiprocessed(
worker_func=self._getitem_multiprocess,
copy_to_global_storage=for_global_storage,
args_list=sample_ids,
workers=workers,
verbose=verbose,
mp_context=mp_context,
desc=desc,
)
return list_sample_dict

    def __len__(self):
        if not self._created:
            raise Exception("you must first call create()")
        if self._sample_ids_mode == "running_int":
            # in this mode self._final_sample_ids holds the dataset length (an integer)
            return self._final_sample_ids
        elif self._sample_ids_mode == "external":
            raise Exception("__len__ is not defined when neither explicit sample_ids nor an integer length was provided.")
return len(self._final_sample_ids)

    # internal methods
@staticmethod
def _process_orig_sample_id(args):
"""
        Process a single sample, without caching
"""
orig_sample_id, pipeline, return_sample_dict = args
sample = create_initial_sample(orig_sample_id)
sample = pipeline(sample)
output_sample_ids = None
if sample is not None:
output_sample_ids = []
if not isinstance(sample, list):
sample = [sample]
for curr_sample in sample:
output_sample_ids.append(get_sample_id(curr_sample))
if not return_sample_dict:
return orig_sample_id, output_sample_ids
return orig_sample_id, output_sample_ids, sample

    def _get_collect_marker_info(self, collect_marker_name: Optional[str]):
"""
Find the required collect marker (OpCollectMarker in the dynamic pipeline).
See OpCollectMarker for more details
:param collect_marker_name: name to identify the required collect marker
        :return: a dictionary with the required info, including: name, op_id and static_keys_deps.
            If collect_marker_name is None, returns default values that instruct running the entire dynamic pipeline.
"""
        # default values for the case where no collect marker is used
if collect_marker_name is None:
return {"name": None, "op_id": None, "static_keys_deps": None}
# find the required collect markers and extract the info
collect_marker_info = None
        for (op, _), op_id in reversed(
            list(zip(self._dynamic_pipeline.ops_and_kwargs, self._dynamic_pipeline._op_ids))
        ):
if isinstance(op, OpCollectMarker):
collect_marker_info_cur = op.get_info()
if collect_marker_info_cur["name"] == collect_marker_name:
if collect_marker_info is None:
collect_marker_info = collect_marker_info_cur
collect_marker_info["op_id"] = op_id
# continue to make sure this is the only one
                else:
                    # raise an error if more than one collect marker with this name is found
                    raise Exception(
                        f"Error: two collect markers with name {collect_marker_name} found in dynamic pipeline"
                    )
if collect_marker_info is None:
raise Exception(f"Error: didn't find collect marker with name {collect_marker_info} in dynamic pipeline.")
return collect_marker_info

    def summary(self) -> str:
        summary_str = ""
        summary_str += f"Type: {type(self).__name__}\n"
        summary_str += f"Num samples: {len(self._final_sample_ids)}\n"
        # TODO
        # summary_str += f"Cacher: {self._cacher.summary()}"
        # summary_str += f"Pipeline static: {self._static_pipeline.summary()}"
        # summary_str += f"Pipeline dynamic: {self._dynamic_pipeline.summary()}"
        return summary_str

    def subset(self, indices: Sequence[int]) -> None:
        """
        Create a subset of the dataset given the provided indices (in place).
        Example:
            For the dataset [-2, 1, 5, 3, 8, 5, 6] and the indices [1, 2, 5], the subset is [1, 5, 5].
        :param indices: indices of the subset - if None, the subset is the whole set.
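        Example usage (dataset is a created DatasetDefault):
            dataset.subset([0, 2])  # keep only the first and third samples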
"""
if indices is None:
# Do nothing, the subset is the whole dataset
return
if not self._created:
raise Exception("you must first call create()")
        # grab the specified data
        # a list comprehension is used rather than operator.itemgetter, since itemgetter with a
        # single index would return the element itself rather than a sequence
        self._final_sample_ids = [self._final_sample_ids[i] for i in indices]
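

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the library API):
    # builds a dataset with empty static/dynamic pipelines and reads samples from it.
    # A real setup would pass lists of (op, kwargs) tuples as ops_and_kwargs.
    static_pipeline = PipelineDefault("static", ops_and_kwargs=[])
    dynamic_pipeline = PipelineDefault("dynamic", ops_and_kwargs=[])
    dataset = DatasetDefault(
        sample_ids=["case_1", "case_2", "case_3"],
        static_pipeline=static_pipeline,
        dynamic_pipeline=dynamic_pipeline,
    )
    dataset.create()
    print(len(dataset))  # 3
    print(dataset[0])  # NDict containing the sample_id "case_1"
    print(dataset.get_all_sample_ids())  # ['case_1', 'case_2', 'case_3']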