interleaving_dataset_reader.py
from typing import Dict, Mapping, Iterable, Union, Optional
import json

from overrides import overrides

from allennlp.common.checks import ConfigurationError
from allennlp.data.dataset_readers.dataset_reader import (
    DatasetReader,
    PathOrStr,
    WorkerInfo,
    DistributedInfo,
)
from allennlp.data.fields import MetadataField
from allennlp.data.instance import Instance

_VALID_SCHEMES = {"round_robin", "all_at_once"}


@DatasetReader.register("interleaving")
class InterleavingDatasetReader(DatasetReader):
"""
A `DatasetReader` that wraps multiple other dataset readers,
and interleaves their instances, adding a `MetadataField` to
indicate the provenance of each instance.
Unlike most of our other dataset readers, here the `file_path` passed into
`read()` should be a JSON-serialized dictionary with one file_path
per wrapped dataset reader (and with corresponding keys).
Registered as a `DatasetReader` with name "interleaving".
# Parameters
readers : `Dict[str, DatasetReader]`
The dataset readers to wrap. The keys of this dictionary will be used
as the values in the MetadataField indicating provenance.
dataset_field_name : `str`, optional (default = `"dataset"`)
The name of the MetadataField indicating which dataset an instance came from.
scheme : `str`, optional (default = `"round_robin"`)
Indicates how to interleave instances. Currently the two options are "round_robin",
which repeatedly cycles through the datasets grabbing one instance from each;
and "all_at_once", which yields all the instances from the first dataset,
then all the instances from the second dataset, and so on. You could imagine also
implementing some sort of over- or under-sampling, although hasn't been done.
"""
def __init__(
self,
readers: Dict[str, DatasetReader],
dataset_field_name: str = "dataset",
scheme: str = "round_robin",
**kwargs,
) -> None:
super().__init__(**kwargs)
self._readers = readers
        self._dataset_field_name = dataset_field_name

        if scheme not in _VALID_SCHEMES:
raise ConfigurationError(f"invalid scheme: {scheme}")
        self._scheme = scheme

    @overrides
def _set_worker_info(self, info: Optional[WorkerInfo]) -> None:
super()._set_worker_info(info)
for reader in self._readers.values():
            reader._set_worker_info(info)

    @overrides
def _set_distributed_info(self, info: Optional[DistributedInfo]) -> None:
super()._set_distributed_info(info)
for reader in self._readers.values():
reader._set_distributed_info(info)
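
    # To illustrate the two schemes with hypothetical datasets (not part of the
    # original module): given {"a": [a1, a2, a3], "b": [b1]}, "round_robin"
    # yields a1, b1, a2, a3, while "all_at_once" yields a1, a2, a3, b1.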
def _read_round_robin(self, datasets: Mapping[str, Iterable[Instance]]) -> Iterable[Instance]:
remaining = set(datasets)
        dataset_iterators = {key: iter(dataset) for key, dataset in datasets.items()}

        while remaining:
for key, dataset in dataset_iterators.items():
if key in remaining:
try:
instance = next(dataset)
instance.fields[self._dataset_field_name] = MetadataField(key)
yield instance
except StopIteration:
                        remaining.remove(key)

    def _read_all_at_once(self, datasets: Mapping[str, Iterable[Instance]]) -> Iterable[Instance]:
for key, dataset in datasets.items():
for instance in dataset:
instance.fields[self._dataset_field_name] = MetadataField(key)
                yield instance

    @overrides
def _read(self, file_path: Union[str, Dict[str, PathOrStr]]) -> Iterable[Instance]:
if isinstance(file_path, str):
try:
file_paths = json.loads(file_path)
except json.JSONDecodeError:
raise ConfigurationError(
"the file_path for the InterleavingDatasetReader "
"needs to be a JSON-serialized dictionary {reader_name -> file_path}"
)
else:
file_paths = file_path

        if file_paths.keys() != self._readers.keys():
            raise ConfigurationError(
                f"mismatched keys: expected {sorted(self._readers)} but got {sorted(file_paths)}"
            )

# Load datasets
        datasets = {key: reader.read(file_paths[key]) for key, reader in self._readers.items()}

        if self._scheme == "round_robin":
yield from self._read_round_robin(datasets)
elif self._scheme == "all_at_once":
yield from self._read_all_at_once(datasets)
else:
            raise RuntimeError("impossible to get here")

    @overrides
def text_to_instance(self, dataset_key: str, *args, **kwargs) -> Instance: # type: ignore
        return self._readers[dataset_key].text_to_instance(*args, **kwargs)  # type: ignore[call-arg]

    @overrides
def apply_token_indexers(self, instance: Instance) -> None:
dataset = instance.fields[self._dataset_field_name].metadata # type: ignore[attr-defined]
self._readers[dataset].apply_token_indexers(instance)
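

# A minimal usage sketch (illustrative only; `MyReaderA`, `MyReaderB`, and the
# file paths are hypothetical stand-ins for concrete `DatasetReader` subclasses):
#
#     import json
#
#     reader = InterleavingDatasetReader(
#         readers={"a": MyReaderA(), "b": MyReaderB()},
#         scheme="round_robin",
#     )
#     file_paths = json.dumps({"a": "/data/a.jsonl", "b": "/data/b.jsonl"})
#     for instance in reader.read(file_paths):
#         print(instance.fields["dataset"].metadata)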