Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit

Permalink
resolve _read type (#4916)
Browse files Browse the repository at this point in the history
* resolve _read type

* fix sharded reader

* fix data loader arg
  • Loading branch information
epwalsh committed Jan 19, 2021
1 parent 5229da8 commit 0f00d4d
Show file tree
Hide file tree
Showing 8 changed files with 36 additions and 36 deletions.
7 changes: 3 additions & 4 deletions allennlp/data/data_loaders/multiprocess_data_loader.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from collections import deque
import logging
from multiprocessing.process import BaseProcess
from os import PathLike
import random
import traceback
from typing import List, Iterator, Optional, Iterable, Union
Expand All @@ -14,7 +13,7 @@
from allennlp.common.tqdm import Tqdm
from allennlp.data.instance import Instance
from allennlp.data.data_loaders.data_loader import DataLoader, TensorDict, allennlp_collate
from allennlp.data.dataset_readers import DatasetReader, WorkerInfo
from allennlp.data.dataset_readers import DatasetReader, WorkerInfo, DatasetReaderInput
from allennlp.data.fields import TextField
from allennlp.data.samplers import BatchSampler
from allennlp.data.vocabulary import Vocabulary
Expand All @@ -39,7 +38,7 @@ class MultiProcessDataLoader(DataLoader):
reader: `DatasetReader`, required
A `DatasetReader` used to load instances from the `data_path`.
data_path: `Union[str, PathLike]`, required
data_path: `DatasetReaderInput`, required
Passed to `DatasetReader.read()`.
!!! Note
Expand Down Expand Up @@ -190,7 +189,7 @@ class MultiProcessDataLoader(DataLoader):
def __init__(
self,
reader: DatasetReader,
data_path: Union[str, PathLike],
data_path: DatasetReaderInput,
*,
batch_size: int = None,
drop_last: bool = False,
Expand Down
1 change: 1 addition & 0 deletions allennlp/data/dataset_readers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from allennlp.data.dataset_readers.dataset_reader import (
DatasetReader,
WorkerInfo,
DatasetReaderInput,
)
from allennlp.data.dataset_readers.babi import BabiReader
from allennlp.data.dataset_readers.conll2003 import Conll2003DatasetReader
Expand Down
4 changes: 2 additions & 2 deletions allennlp/data/dataset_readers/babi.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from overrides import overrides

from allennlp.common.file_utils import cached_path
from allennlp.data.dataset_readers.dataset_reader import DatasetReader
from allennlp.data.dataset_readers.dataset_reader import DatasetReader, PathOrStr
from allennlp.data.instance import Instance
from allennlp.data.fields import Field, TextField, ListField, IndexField
from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer
Expand Down Expand Up @@ -45,7 +45,7 @@ def __init__(
self._token_indexers = token_indexers or {"tokens": SingleIdTokenIndexer()}

@overrides
def _read(self, file_path: str):
def _read(self, file_path: PathOrStr):
# if `file_path` is a URL, redirect to the cache
file_path = cached_path(file_path)

Expand Down
4 changes: 2 additions & 2 deletions allennlp/data/dataset_readers/conll2003.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from allennlp.common.checks import ConfigurationError
from allennlp.common.file_utils import cached_path
from allennlp.data.dataset_readers.dataset_reader import DatasetReader
from allennlp.data.dataset_readers.dataset_reader import DatasetReader, PathOrStr
from allennlp.data.dataset_readers.dataset_utils import to_bioul
from allennlp.data.fields import TextField, SequenceLabelField, Field, MetadataField
from allennlp.data.instance import Instance
Expand Down Expand Up @@ -107,7 +107,7 @@ def __init__(
self._original_coding_scheme = "IOB1"

@overrides
def _read(self, file_path: str) -> Iterable[Instance]:
def _read(self, file_path: PathOrStr) -> Iterable[Instance]:
# if `file_path` is a URL, redirect to the cache
file_path = cached_path(file_path)

Expand Down
17 changes: 7 additions & 10 deletions allennlp/data/dataset_readers/dataset_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,28 +189,25 @@ def read(self, file_path: DatasetReaderInput) -> Iterator[Instance]:
"""
Returns an iterator of instances that can be read from the file path.
"""
if not isinstance(file_path, str):
if isinstance(file_path, list):
file_path = [str(f) for f in file_path]
elif isinstance(file_path, dict):
file_path = {k: str(v) for k, v in file_path.items()}
else:
file_path = str(file_path)

for instance in self._multi_worker_islice(self._read(file_path)): # type: ignore
if self._worker_info is None:
# If not running in a subprocess, it's safe to apply the token_indexers right away.
self.apply_token_indexers(instance)
yield instance

def _read(self, file_path: str) -> Iterable[Instance]:
def _read(self, file_path) -> Iterable[Instance]:
"""
Reads the instances from the given file_path and returns them as an
Reads the instances from the given `file_path` and returns them as an
`Iterable`.
You are strongly encouraged to use a generator so that users can
read a dataset in a lazy way, if they so choose.
"""
# NOTE: `file_path` is left untyped here on purpose.
# Technically the type should be `DatasetReaderInput`, but many subclass
# implementations of `DatasetReader` define their `_read()` method to take a more
# specific type, such as just `str`. But that would be a type error
# according to mypy: https://mypy.readthedocs.io/en/stable/common_issues.html#incompatible-overrides
raise NotImplementedError

def text_to_instance(self, *inputs) -> Instance:
Expand Down
23 changes: 13 additions & 10 deletions allennlp/data/dataset_readers/interleaving_dataset_reader.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from typing import Dict, Mapping, Iterable
from typing import Dict, Mapping, Iterable, Union
import json

from allennlp.common.checks import ConfigurationError
from allennlp.data.dataset_readers.dataset_reader import DatasetReader
from allennlp.data.dataset_readers.dataset_reader import DatasetReader, PathOrStr
from allennlp.data.fields import MetadataField
from allennlp.data.instance import Instance

Expand Down Expand Up @@ -72,14 +72,17 @@ def _read_all_at_once(self, datasets: Mapping[str, Iterable[Instance]]) -> Itera
instance.fields[self._dataset_field_name] = MetadataField(key)
yield instance

def _read(self, file_path: str) -> Iterable[Instance]:
try:
file_paths = json.loads(file_path)
except json.JSONDecodeError:
raise ConfigurationError(
"the file_path for the InterleavingDatasetReader "
"needs to be a JSON-serialized dictionary {reader_name -> file_path}"
)
def _read(self, file_path: Union[str, Dict[str, PathOrStr]]) -> Iterable[Instance]:
if isinstance(file_path, str):
try:
file_paths = json.loads(file_path)
except json.JSONDecodeError:
raise ConfigurationError(
"the file_path for the InterleavingDatasetReader "
"needs to be a JSON-serialized dictionary {reader_name -> file_path}"
)
else:
file_paths = file_path

if file_paths.keys() != self._readers.keys():
raise ConfigurationError("mismatched keys")
Expand Down
6 changes: 3 additions & 3 deletions allennlp/data/dataset_readers/sharded_dataset_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from allennlp.common.checks import ConfigurationError
from allennlp.common.file_utils import cached_path
from allennlp.data.dataset_readers.dataset_reader import DatasetReader
from allennlp.data.dataset_readers.dataset_reader import DatasetReader, PathOrStr
from allennlp.data.instance import Instance


Expand Down Expand Up @@ -52,7 +52,7 @@ def text_to_instance(self, *args, **kwargs) -> Instance:
"""
return self.reader.text_to_instance(*args, **kwargs) # type: ignore

def _read(self, file_path: str) -> Iterable[Instance]:
def _read(self, file_path: PathOrStr) -> Iterable[Instance]:
try:
maybe_extracted_archive = cached_path(file_path, extract_archive=True)
if not os.path.isdir(maybe_extracted_archive):
Expand All @@ -67,7 +67,7 @@ def _read(self, file_path: str) -> Iterable[Instance]:
raise ConfigurationError(f"No files found in {file_path}")
except FileNotFoundError:
# Not a local or remote archive, so treat as a glob.
shards = glob.glob(file_path)
shards = glob.glob(str(file_path))
if not shards:
raise ConfigurationError(f"No files found matching {file_path}")

Expand Down
10 changes: 5 additions & 5 deletions tests/data/dataset_readers/interleaving_dataset_reader_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,11 @@ def test_round_robin(self):
reader = InterleavingDatasetReader(readers)
data_dir = self.FIXTURES_ROOT / "data"

file_path = f"""{{
"a": "{data_dir / 'babi.txt'}",
"b": "{data_dir / 'conll2003.txt'}",
"c": "{data_dir / 'conll2003.txt'}"
}}"""
file_path = {
"a": data_dir / "babi.txt",
"b": data_dir / "conll2003.txt",
"c": data_dir / "conll2003.txt",
}

instances = list(reader.read(file_path))
first_three_keys = {instance.fields["dataset"].metadata for instance in instances[:3]}
Expand Down

0 comments on commit 0f00d4d

Please sign in to comment.