
Commit

ensure Vocab.from_files and ShardedDatasetReader can handle archives (#4371)

* ensure Vocab.from_files can handle archives

* handle archive with ShardedDatasetReader

* throw helpful ConfigurationError

* update docstring
epwalsh committed Jun 22, 2020
1 parent 20afe6c commit ffc5184
Showing 9 changed files with 250 additions and 17 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -17,6 +17,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- A method to ModelTestCase for running basic model tests when you aren't using config files.

### Added

- Added an option to `file_utils.cached_path` to automatically extract archives.
- Added the ability to pass an archive file instead of a local directory to `Vocab.from_files`.
- Added the ability to pass an archive file instead of a glob to `ShardedDatasetReader`.

## [v1.0.0](https://github.com/allenai/allennlp/releases/tag/v1.0.0) - 2020-06-16

### Fixed
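Taken together, the three additions listed in the changelog can be exercised roughly as follows. This is a hedged sketch, not code from the commit: the file names and URL are hypothetical, and `SequenceTaggingDatasetReader` is used only as an illustrative base reader.

```python
from allennlp.common.file_utils import cached_path
from allennlp.data import Vocabulary
from allennlp.data.dataset_readers import SequenceTaggingDatasetReader, ShardedDatasetReader

# 1. cached_path can now extract zip / tar.gz archives and return the extracted directory.
data_dir = cached_path("https://example.com/my-dataset.tar.gz", extract_archive=True)

# 2. Vocabulary.from_files accepts an archive file (e.g. a model archive) instead of a directory.
vocab = Vocabulary.from_files("model.tar.gz")

# 3. ShardedDatasetReader.read accepts an archive of shard files instead of a glob.
reader = ShardedDatasetReader(base_reader=SequenceTaggingDatasetReader())
instances = list(reader.read("shards.tar.gz"))
```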
77 changes: 71 additions & 6 deletions allennlp/common/file_utils.py
@@ -12,6 +12,9 @@
from typing import Optional, Tuple, Union, IO, Callable, Set, List
from hashlib import sha256
from functools import wraps
from zipfile import ZipFile, is_zipfile
import tarfile
import shutil

import boto3
import botocore
@@ -61,7 +64,7 @@ def url_to_filename(url: str, etag: str = None) -> str:
return filename


def filename_to_url(filename: str, cache_dir: str = None) -> Tuple[str, str]:
def filename_to_url(filename: str, cache_dir: Union[str, Path] = None) -> Tuple[str, str]:
"""
Return the url and etag (which may be `None`) stored for `filename`.
Raise `FileNotFoundError` if `filename` or its stored metadata do not exist.
@@ -85,34 +88,96 @@ def filename_to_url(filename: str, cache_dir: str = None) -> Tuple[str, str]:
return url, etag


def cached_path(url_or_filename: Union[str, Path], cache_dir: str = None) -> str:
def cached_path(
url_or_filename: Union[str, Path],
cache_dir: Union[str, Path] = None,
extract_archive: bool = False,
force_extract: bool = False,
) -> str:
"""
Given something that might be a URL (or might be a local path),
determine which. If it's a URL, download the file and cache it, and
return the path to the cached file. If it's already a local path,
make sure the file exists and then return the path.
# Parameters
url_or_filename : `Union[str, Path]`
A URL or local file to parse and possibly download.
cache_dir : `Union[str, Path]`, optional (default = `None`)
The directory to cache downloads.
extract_archive : `bool`, optional (default = `False`)
If `True`, then zip or tar.gz archives will be automatically extracted,
in which case the path to the extracted directory is returned.
force_extract : `bool`, optional (default = `False`)
If `True` and the file is an archive file, it will be extracted regardless
of whether or not the extracted directory already exists.
"""
if cache_dir is None:
cache_dir = CACHE_DIRECTORY

if isinstance(url_or_filename, Path):
url_or_filename = str(url_or_filename)

url_or_filename = os.path.expanduser(url_or_filename)
parsed = urlparse(url_or_filename)

file_path: str
extraction_path: Optional[str] = None

if parsed.scheme in ("http", "https", "s3"):
# URL, so get it from the cache (downloading if necessary)
return get_from_cache(url_or_filename, cache_dir)
file_path = get_from_cache(url_or_filename, cache_dir)

if extract_archive and (is_zipfile(file_path) or tarfile.is_tarfile(file_path)):
# This is the path the file should be extracted to.
# For example ~/.allennlp/cache/234234.21341 -> ~/.allennlp/cache/234234.21341-extracted
extraction_path = file_path + "-extracted"

elif os.path.exists(url_or_filename):
# File, and it exists.
return url_or_filename
file_path = url_or_filename

if extract_archive and (is_zipfile(file_path) or tarfile.is_tarfile(file_path)):
# This is the path the file should be extracted to.
# For example model.tar.gz -> model-tar-gz-extracted
extraction_dir, extraction_name = os.path.split(file_path)
extraction_name = extraction_name.replace(".", "-") + "-extracted"
extraction_path = os.path.join(extraction_dir, extraction_name)

elif parsed.scheme == "":
# File, but it doesn't exist.
raise FileNotFoundError("file {} not found".format(url_or_filename))

else:
# Something unknown
raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))

if extraction_path is not None:
# No need to extract again.
if os.path.isdir(extraction_path) and os.listdir(extraction_path) and not force_extract:
return extraction_path

# Extract it.
with FileLock(file_path + ".lock"):
shutil.rmtree(extraction_path, ignore_errors=True)
os.makedirs(extraction_path)
if is_zipfile(file_path):
with ZipFile(file_path, "r") as zip_file:
zip_file.extractall(extraction_path)
zip_file.close()
else:
tar_file = tarfile.open(file_path)
tar_file.extractall(extraction_path)
tar_file.close()

return extraction_path

return file_path


def is_url_or_existing_file(url_or_filename: Union[str, Path, None]) -> bool:
"""
@@ -226,7 +291,7 @@ def _http_get(url: str, temp_file: IO) -> None:
progress.close()


def _find_latest_cached(url: str, cache_dir: str) -> Optional[str]:
def _find_latest_cached(url: str, cache_dir: Union[str, Path]) -> Optional[str]:
filename = url_to_filename(url)
cache_path = os.path.join(cache_dir, filename)
candidates: List[Tuple[str, float]] = []
Expand Down Expand Up @@ -283,7 +348,7 @@ def __exit__(self, exc_type, exc_value, traceback):


# TODO(joelgrus): do we want to do checksums or anything like that?
def get_from_cache(url: str, cache_dir: str = None) -> str:
def get_from_cache(url: str, cache_dir: Union[str, Path] = None) -> str:
"""
Given a URL, look for the corresponding dataset in the local cache.
If it's not there, download it. Then return the path to the cached file.
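Under the change above, the extraction location depends on whether the archive is local or remote: a local archive is unpacked into a sibling directory named after the file, while a downloaded archive is unpacked next to its cache entry. A minimal sketch with hypothetical paths:

```python
from allennlp.common.file_utils import cached_path

# Local archive: "model.tar.gz" is extracted (once) to "model-tar-gz-extracted"
# in the same directory, and that directory path is returned.
extracted_dir = cached_path("model.tar.gz", extract_archive=True)

# Remote archive: the file is downloaded into the cache first, then extracted to
# "<cached-file>-extracted". force_extract=True re-extracts even if the extracted
# directory already exists and is non-empty.
remote_dir = cached_path(
    "https://example.com/data.zip",
    extract_archive=True,
    force_extract=True,
)
```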
31 changes: 27 additions & 4 deletions allennlp/data/dataset_readers/sharded_dataset_reader.py
@@ -1,9 +1,12 @@
import glob
import logging
import os
import torch
from typing import Iterable

from allennlp.common import util
from allennlp.common.checks import ConfigurationError
from allennlp.common.file_utils import cached_path
from allennlp.data.dataset_readers.dataset_reader import DatasetReader
from allennlp.data.instance import Instance

@@ -15,9 +18,12 @@
class ShardedDatasetReader(DatasetReader):
"""
Wraps another dataset reader and uses it to read from multiple input files.
Note that in this case the `file_path` passed to `read()` should be a glob,
and that the dataset reader will return instances from all files matching
the glob.
Note that in this case the `file_path` passed to `read()` should either be a glob path
or a path or URL to an archive file ('.zip' or '.tar.gz').
The dataset reader will return instances from all files matching the glob, or all
files within the archive.
The order the files are processed in is deterministic to enable the
instances to be filtered according to worker rank in the distributed case.
@@ -50,7 +56,24 @@ def text_to_instance(self, *args, **kwargs) -> Instance:
return self.reader.text_to_instance(*args, **kwargs) # type: ignore

def _read(self, file_path: str) -> Iterable[Instance]:
shards = glob.glob(file_path)
try:
maybe_extracted_archive = cached_path(file_path, extract_archive=True)
if not os.path.isdir(maybe_extracted_archive):
# This isn't a directory, so `file_path` is just a file.
raise ConfigurationError(f"{file_path} should be an archive or directory")
shards = [
os.path.join(maybe_extracted_archive, p)
for p in os.listdir(maybe_extracted_archive)
if not p.startswith(".")
]
if not shards:
raise ConfigurationError(f"No files found in {file_path}")
except FileNotFoundError:
# Not a local or remote archive, so treat as a glob.
shards = glob.glob(file_path)
if not shards:
raise ConfigurationError(f"No files found matching {file_path}")

# Ensure a consistent order.
shards.sort()

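With this change, `ShardedDatasetReader._read()` first tries to resolve `file_path` through `cached_path(..., extract_archive=True)`; if that raises `FileNotFoundError`, the path is treated as a glob as before. A sketch of both modes, with hypothetical paths and `SequenceTaggingDatasetReader` standing in for an arbitrary base reader:

```python
from allennlp.data.dataset_readers import SequenceTaggingDatasetReader, ShardedDatasetReader

reader = ShardedDatasetReader(base_reader=SequenceTaggingDatasetReader())

# Glob over shard files, as before.
glob_instances = list(reader.read("/data/shards/identical_*.tsv"))

# New: a '.zip' or '.tar.gz' archive of shards. The archive is extracted and every
# non-hidden file inside is handed to the base reader; an empty archive or a glob
# that matches nothing raises ConfigurationError.
archive_instances = list(reader.read("/data/shards/all_data.tar.gz"))
```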
20 changes: 17 additions & 3 deletions allennlp/data/vocabulary.py
@@ -13,10 +13,11 @@

from filelock import FileLock

from allennlp.common.util import namespace_match
from allennlp.common import Registrable
from allennlp.common.file_utils import cached_path
from allennlp.common.checks import ConfigurationError
from allennlp.common.tqdm import Tqdm
from allennlp.common.util import namespace_match

if TYPE_CHECKING:
from allennlp.data import instance as adi # noqa
@@ -311,17 +312,30 @@ def from_files(
oov_token: Optional[str] = DEFAULT_OOV_TOKEN,
) -> "Vocabulary":
"""
Loads a `Vocabulary` that was serialized using `save_to_files`.
Loads a `Vocabulary` that was serialized either using `save_to_files` or inside
a model archive file.
# Parameters
directory : `str`
The directory containing the serialized vocabulary.
The directory or archive file containing the serialized vocabulary.
"""
logger.info("Loading token dictionary from %s.", directory)
padding_token = padding_token if padding_token is not None else DEFAULT_PADDING_TOKEN
oov_token = oov_token if oov_token is not None else DEFAULT_OOV_TOKEN

if not os.path.isdir(directory):
base_directory = cached_path(directory, extract_archive=True)
# For convenience we'll check for a 'vocabulary' subdirectory of the archive.
# That way you can use model archives directly.
vocab_subdir = os.path.join(base_directory, "vocabulary")
if os.path.isdir(vocab_subdir):
directory = vocab_subdir
elif os.path.isdir(base_directory):
directory = base_directory
else:
raise ConfigurationError(f"{directory} is neither a directory nor an archive")

# We use a lock file to avoid race conditions where multiple processes
# might be reading/writing from/to the same vocab files at once.
with FileLock(os.path.join(directory, ".lock")):
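A sketch of loading a vocabulary from an archive under the new `from_files` behavior (the archive name and URL are hypothetical):

```python
from allennlp.data import Vocabulary

# A model archive containing a 'vocabulary/' subdirectory works directly:
# the archive is extracted via cached_path and that subdirectory is used.
vocab = Vocabulary.from_files("model.tar.gz")

# An archive (or URL to one) whose top level is the serialized vocabulary also works.
# A path that is neither a directory nor an archive raises ConfigurationError.
vocab = Vocabulary.from_files("https://example.com/vocab.tar.gz")
```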
Binary file added test_fixtures/data/vocab.tar.gz
Binary file not shown.
Binary file added test_fixtures/data/vocab.zip
Binary file not shown.
74 changes: 74 additions & 0 deletions tests/common/file_utils_test.py
@@ -3,6 +3,7 @@
import pathlib
import json
import time
import shutil

import pytest
import responses
@@ -254,6 +255,79 @@ def test_open_compressed(self):
assert compressed_lines == uncompressed_lines


class TestCachedPathWithArchive(AllenNlpTestCase):
def setup_method(self):
super().setup_method()
self.tar_file = self.TEST_DIR / "utf-8.tar.gz"
shutil.copyfile(
self.FIXTURES_ROOT / "utf-8_sample" / "archives" / "utf-8.tar.gz", self.tar_file
)
self.zip_file = self.TEST_DIR / "utf-8.zip"
shutil.copyfile(
self.FIXTURES_ROOT / "utf-8_sample" / "archives" / "utf-8.zip", self.zip_file
)

@staticmethod
def check_extracted(extracted: str):
assert os.path.isdir(extracted)
assert os.path.exists(os.path.join(extracted, "dummy.txt"))
assert os.path.exists(os.path.join(extracted, "folder/utf-8_sample.txt"))

def test_cached_path_extract_local_tar(self):
extracted = cached_path(self.tar_file, cache_dir=self.TEST_DIR, extract_archive=True)
assert os.path.basename(extracted) == "utf-8-tar-gz-extracted"
self.check_extracted(extracted)

def test_cached_path_extract_local_zip(self):
extracted = cached_path(self.zip_file, cache_dir=self.TEST_DIR, extract_archive=True)
assert os.path.basename(extracted) == "utf-8-zip-extracted"
self.check_extracted(extracted)

@responses.activate
def test_cached_path_extract_remote_tar(self):
url = "http://fake.datastore.com/utf-8.tar.gz"
byt = open(self.tar_file, "rb").read()

responses.add(
responses.GET,
url,
body=byt,
status=200,
content_type="application/tar+gzip",
stream=True,
headers={"Content-Length": str(len(byt))},
)
responses.add(
responses.HEAD, url, status=200, headers={"ETag": "fake-etag"},
)

extracted = cached_path(url, cache_dir=self.TEST_DIR, extract_archive=True)
assert extracted.endswith("-extracted")
self.check_extracted(extracted)

@responses.activate
def test_cached_path_extract_remote_zip(self):
url = "http://fake.datastore.com/utf-8.zip"
byt = open(self.zip_file, "rb").read()

responses.add(
responses.GET,
url,
body=byt,
status=200,
content_type="application/zip",
stream=True,
headers={"Content-Length": str(len(byt))},
)
responses.add(
responses.HEAD, url, status=200, headers={"ETag": "fake-etag"},
)

extracted = cached_path(url, cache_dir=self.TEST_DIR, extract_archive=True)
assert extracted.endswith("-extracted")
self.check_extracted(extracted)


class TestCacheFile(AllenNlpTestCase):
def test_temp_file_removed_on_error(self):
cache_filename = self.TEST_DIR / "cache_file"
26 changes: 22 additions & 4 deletions tests/data/dataset_readers/sharded_dataset_reader_test.py
@@ -1,4 +1,7 @@
from collections import Counter
import glob
import os
import tarfile
from typing import Tuple

from allennlp.common.testing import AllenNlpTestCase
@@ -33,12 +36,21 @@ def setup_method(self) -> None:

self.identical_files_glob = str(self.TEST_DIR / "identical_*.tsv")

def test_sharded_read(self):
reader = ShardedDatasetReader(base_reader=self.base_reader)
# Also create an archive with all of these files to ensure that we can
# pass the archive directory.
current_dir = os.getcwd()
os.chdir(self.TEST_DIR)
self.archive_filename = self.TEST_DIR / "all_data.tar.gz"
with tarfile.open(self.archive_filename, "w:gz") as archive:
for file_path in glob.glob("identical_*.tsv"):
archive.add(file_path)
os.chdir(current_dir)

all_instances = []
self.reader = ShardedDatasetReader(base_reader=self.base_reader)

for instance in reader.read(self.identical_files_glob):
def read_and_check_instances(self, filepath: str):
all_instances = []
for instance in self.reader.read(filepath):
all_instances.append(instance)

# 100 files * 4 sentences / file
@@ -52,3 +64,9 @@ def test_sharded_read(self):
assert counts[("dogs", "are", "animals", ".", "N", "V", "N", "N")] == 100
assert counts[("snakes", "are", "animals", ".", "N", "V", "N", "N")] == 100
assert counts[("birds", "are", "animals", ".", "N", "V", "N", "N")] == 100

def test_sharded_read_glob(self):
self.read_and_check_instances(self.identical_files_glob)

def test_sharded_read_archive(self):
self.read_and_check_instances(str(self.archive_filename))
