from typing import Dict, Mapping, Iterable
import json

from allennlp.common.checks import ConfigurationError
from allennlp.data.dataset_readers.dataset_reader import DatasetReader
from allennlp.data.fields import MetadataField
from allennlp.data.instance import Instance

_VALID_SCHEMES = {"round_robin", "all_at_once"}


@DatasetReader.register("interleaving")
class InterleavingDatasetReader(DatasetReader):
"""
A `DatasetReader` that wraps multiple other dataset readers,
and interleaves their instances, adding a `MetadataField` to
indicate the provenance of each instance.
Unlike most of our other dataset readers, here the `file_path` passed into
`read()` should be a JSON-serialized dictionary with one file_path
per wrapped dataset reader (and with corresponding keys).
Registered as a `DatasetReader` with name "interleaving".
# Parameters
readers : `Dict[str, DatasetReader]`
The dataset readers to wrap. The keys of this dictionary will be used
as the values in the MetadataField indicating provenance.
dataset_field_name : `str`, optional (default = `"dataset"`)
The name of the MetadataField indicating which dataset an instance came from.
scheme : `str`, optional (default = `"round_robin"`)
Indicates how to interleave instances. Currently the two options are "round_robin",
which repeatedly cycles through the datasets grabbing one instance from each;
and "all_at_once", which yields all the instances from the first dataset,
then all the instances from the second dataset, and so on. You could imagine also
implementing some sort of over- or under-sampling, although hasn't been done.
"""

    def __init__(
        self,
        readers: Dict[str, DatasetReader],
        dataset_field_name: str = "dataset",
        scheme: str = "round_robin",
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self._readers = readers
        self._dataset_field_name = dataset_field_name

        if scheme not in _VALID_SCHEMES:
            raise ConfigurationError(f"invalid scheme: {scheme}")
        self._scheme = scheme
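
    # Since the class is registered under "interleaving", it can also be built
    # from configuration. A hedged sketch (the wrapped reader types and key
    # names below are hypothetical):
    #
    #     "dataset_reader": {
    #         "type": "interleaving",
    #         "readers": {
    #             "snli": {"type": "..."},
    #             "mnli": {"type": "..."}
    #         },
    #         "scheme": "round_robin"
    #     }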

    def _read_round_robin(self, datasets: Mapping[str, Iterable[Instance]]) -> Iterable[Instance]:
        remaining = set(datasets)
        dataset_iterators = {key: iter(dataset) for key, dataset in datasets.items()}

        while remaining:
            # Cycle through the datasets, yielding one instance from each
            # iterator that hasn't been exhausted yet.
            for key, dataset in dataset_iterators.items():
                if key in remaining:
                    try:
                        instance = next(dataset)
                        instance.fields[self._dataset_field_name] = MetadataField(key)
                        yield instance
                    except StopIteration:
                        # This dataset is exhausted; stop pulling from it.
                        remaining.remove(key)
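
    # For example, given datasets {"a": [a1, a2, a3], "b": [b1]}, round-robin
    # yields a1, b1, a2, a3: exhausted readers are simply skipped on later passes.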

    def _read_all_at_once(self, datasets: Mapping[str, Iterable[Instance]]) -> Iterable[Instance]:
        for key, dataset in datasets.items():
            for instance in dataset:
                instance.fields[self._dataset_field_name] = MetadataField(key)
                yield instance

    def _read(self, file_path: str) -> Iterable[Instance]:
        try:
            file_paths = json.loads(file_path)
        except json.JSONDecodeError:
            raise ConfigurationError(
                "the file_path for the InterleavingDatasetReader "
                "needs to be a JSON-serialized dictionary {reader_name -> file_path}"
            )

        if file_paths.keys() != self._readers.keys():
            raise ConfigurationError(
                f"mismatched keys: got {set(file_paths)} but expected {set(self._readers)}"
            )

        # Load datasets
        datasets = {key: reader.read(file_paths[key]) for key, reader in self._readers.items()}

        if self._scheme == "round_robin":
            yield from self._read_round_robin(datasets)
        elif self._scheme == "all_at_once":
            yield from self._read_all_at_once(datasets)
        else:
            raise RuntimeError("impossible to get here")

    def text_to_instance(self) -> Instance:  # type: ignore
        raise RuntimeError("text_to_instance doesn't make sense here")
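

# A minimal usage sketch (illustrative only, not part of the original module).
# It defines a toy reader that yields one Instance per line of a text file,
# wraps two copies of it, and interleaves them round-robin. The key names and
# the files "a.txt" and "b.txt" are hypothetical.
if __name__ == "__main__":

    class _LineReader(DatasetReader):
        """A toy reader: one Instance per line, stored in a MetadataField."""

        def _read(self, file_path: str) -> Iterable[Instance]:
            with open(file_path) as data_file:
                for line in data_file:
                    yield Instance({"line": MetadataField(line.strip())})

    reader = InterleavingDatasetReader(readers={"a": _LineReader(), "b": _LineReader()})

    # Note that read() takes a JSON-serialized dict of file paths, not a path.
    file_paths = json.dumps({"a": "a.txt", "b": "b.txt"})
    for instance in reader.read(file_paths):
        print(instance.fields["dataset"].metadata, instance.fields["line"].metadata)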