fix: various bugs in dataset import (#882)
* fix: various bug fixes

* refactor: parallel downloads
m-alisafaee committed Jan 9, 2020
1 parent 003b277 commit be28bf5
Showing 9 changed files with 290 additions and 250 deletions.
2 changes: 1 addition & 1 deletion conftest.py
@@ -180,7 +180,7 @@ def client(project):
@pytest.fixture
def dataset(client):
"""Create a dataset."""
with client.with_dataset(name='dataset', create=True) as dataset:
with client.with_dataset('dataset', create=True) as dataset:
dataset.creator = [{
'affiliation': 'xxx',
'email': 'me@example.com',
116 changes: 25 additions & 91 deletions renku/cli/dataset.py
@@ -240,11 +240,7 @@
.. note:: The ``unlink`` command does not delete files,
only the dataset record.
"""
import multiprocessing as mp
import os
from functools import partial
from pathlib import Path
from time import sleep

import click
import editor
@@ -261,6 +257,7 @@
from renku.core.commands.format.dataset_tags import DATASET_TAGS_FORMATS
from renku.core.commands.format.datasets import DATASETS_FORMATS
from renku.core.errors import DatasetNotFound, InvalidAccessToken
from renku.core.management.datasets import DownloadProgressCallback


def prompt_access_token(exporter):
@@ -297,71 +294,6 @@ def prompt_tag_selection(tags):
return None


def download_file_with_progress(extract, data_folder, file, chunk_size=16384):
"""Download a file with progress tracking."""
global current_process_position

local_filename = Path(file.filename).name
download_to = Path(data_folder) / Path(local_filename)

def extract_dataset(data_folder_, filename):
"""Extract downloaded dataset."""
import patoolib
filepath = Path(data_folder_) / Path(filename)
patoolib.extract_archive(filepath, outdir=data_folder_)
filepath.unlink()

def stream_to_file(request):
"""Stream bytes to file."""
with open(str(download_to), 'wb') as f_:
scaling_factor = 1e-6
unit = 'MB'

# We round sizes to 0.1, files smaller than 1e5 would
# get rounded to 0, so we display bytes instead
if file.filesize < 1e5:
scaling_factor = 1.0
unit = 'B'

total = round(file.filesize * scaling_factor, 1)
progressbar_ = tqdm(
total=total,
position=current_process_position,
desc=file.filename[:32],
bar_format=(
'{{percentage:3.0f}}% '
'{{n_fmt}}{unit}/{{total_fmt}}{unit}| '
'{{bar}} | {{desc}}'.format(unit=unit)
),
leave=False,
)

try:
bytes_downloaded = 0
for chunk in request.iter_content(chunk_size=chunk_size):
if chunk: # remove keep-alive chunks
f_.write(chunk)
bytes_downloaded += chunk_size
progressbar_.n = min(
float(
'{0:.1f}'.format(
bytes_downloaded * scaling_factor
)
), total
)
progressbar_.update(0)
finally:
sleep(0.1)
progressbar_.close()

if extract:
extract_dataset(data_folder, local_filename)

with requests.get(file.url.geturl(), stream=True) as r:
r.raise_for_status()
stream_to_file(r)


@click.group(invoke_without_command=True)
@click.option('--revision', default=None)
@click.option('--datadir', default='data', type=click.Path(dir_okay=True))
@@ -646,37 +578,39 @@ def import_(uri, short_name, extract):
Supported providers: [Zenodo, Dataverse]
"""
manager = mp.Manager()
id_queue = manager.Queue()

pool_size = min(int(os.getenv('RENKU_POOL_SIZE', mp.cpu_count() // 2)), 4)

for i in range(pool_size):
id_queue.put(i)

def _init(lock, id_queue):
"""Set up tqdm lock and worker process index.
See https://stackoverflow.com/a/42817946
Fixes tqdm line position when |files| > terminal-height
so only |workers| progressbars are shown at a time
"""
global current_process_position
current_process_position = id_queue.get()
tqdm.set_lock(lock)

import_dataset(
uri=uri,
short_name=short_name,
extract=extract,
with_prompt=True,
pool_init_fn=_init,
pool_init_args=(mp.RLock(), id_queue),
download_file_fn=download_file_with_progress
progress=_DownloadProgressbar
)
click.secho('OK', fg='green')
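
For reference, a minimal sketch of driving the same import from Python rather than the CLI, mirroring the call above and the import_dataset signature shown later in renku/core/commands/dataset.py. The import path, the DOI value, and the behaviour without a prompt are assumptions for illustration, not code from this commit.

    from renku.core.commands.dataset import import_dataset  # assumed import path

    import_dataset(
        uri='10.5281/zenodo.1234567',   # placeholder Zenodo DOI or Dataverse URL
        short_name='my-dataset',
        extract=True,                   # unpack downloaded archives
        with_prompt=False,              # assumed: skip the interactive confirmation
        progress=_DownloadProgressbar,  # the progress class defined below
    )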


class _DownloadProgressbar(DownloadProgressCallback):
def __init__(self, description, total_size):
"""Default initializer."""
self._progressbar = tqdm(
total=total_size,
unit='iB',
unit_scale=True,
desc=description,
leave=False,
bar_format='{desc:.32}: {percentage:3.0f}%|{bar}{r_bar}'
)

def update(self, size):
"""Update the status."""
if self._progressbar:
self._progressbar.update(size)

def finalize(self):
"""Called once when the download is finished."""
if self._progressbar:
self._progressbar.close()
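
DownloadProgressCallback is imported from renku.core.management.datasets, but its definition is not part of this diff. Below is a minimal sketch, assuming the contract that _DownloadProgressbar implements above, together with an illustrative driver loop; the driver is an assumption about how the client feeds downloaded chunks to the callback, not code from this commit.

    import requests


    class DownloadProgressCallback:
        """Assumed shape of the base class (defined in
        renku/core/management/datasets.py, not shown in this diff)."""

        def __init__(self, description, total_size):
            """Create a progress reporter for a single file."""

        def update(self, size):
            """Report that ``size`` more bytes have been downloaded."""

        def finalize(self):
            """Called once when the download is finished."""


    def download(url, filename, progress_class=None, chunk_size=16384):
        """Illustrative driver: stream ``url`` to ``filename``, reporting progress."""
        with requests.get(url, stream=True) as response:
            response.raise_for_status()
            total = int(response.headers.get('Content-Length', 0))
            progress = None
            if progress_class:
                progress = progress_class(description=filename, total_size=total)
            try:
                with open(filename, 'wb') as file_:
                    for chunk in response.iter_content(chunk_size=chunk_size):
                        if chunk:  # skip keep-alive chunks
                            file_.write(chunk)
                            if progress:
                                progress.update(len(chunk))
            finally:
                if progress:
                    progress.finalize()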


@dataset.command('update')
@click.argument('names', nargs=-1)
@click.option(
139 changes: 35 additions & 104 deletions renku/core/commands/dataset.py
@@ -17,19 +17,12 @@
# limitations under the License.
"""Repository datasets management."""

import multiprocessing as mp
import os
import re
import tempfile
from collections import OrderedDict
from contextlib import contextmanager
from multiprocessing import freeze_support
from pathlib import Path
from urllib.parse import ParseResult

import click
import git
import requests
import yaml
from requests import HTTPError

@@ -46,7 +39,6 @@
from renku.core.models.provenance.agents import Person
from renku.core.models.refs import LinkReference
from renku.core.models.tabulate import tabulate
from renku.core.utils.doi import extract_doi
from renku.core.utils.urls import remove_credentials

from .client import pass_local_client
@@ -55,32 +47,6 @@
from .format.datasets import DATASETS_FORMATS


def default_download_file(extract, data_folder, file, chunk_size=16384):
"""Download a file."""
local_filename = Path(file.filename).name
download_to = Path(data_folder) / Path(local_filename)

def extract_dataset(data_folder_, filename):
"""Extract downloaded dataset."""
import patoolib
filepath = Path(data_folder_) / Path(filename)
patoolib.extract_archive(filepath, outdir=data_folder_)
filepath.unlink()

def stream_to_file(request):
"""Stream bytes to file."""
with open(str(download_to), 'wb') as f_:
for chunk in request.iter_content(chunk_size=chunk_size):
if chunk: # remove keep-alive chunks
f_.write(chunk)
if extract:
extract_dataset(data_folder, local_filename)

with requests.get(file.url.geturl(), stream=True) as r:
r.raise_for_status()
stream_to_file(r)
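
Both removed download helpers (download_file_with_progress in renku/cli/dataset.py and default_download_file above) share the same stream-then-extract pattern, which this commit moves behind client.add_data_to_dataset via the new extract and progress arguments. A condensed, self-contained sketch of that pattern, with an illustrative function name:

    from pathlib import Path

    import requests


    def stream_and_extract(url, filename, data_folder, extract=False,
                           chunk_size=16384):
        """Stream a remote file into ``data_folder`` and optionally unpack it."""
        download_to = Path(data_folder) / Path(filename).name

        with requests.get(url, stream=True) as response:
            response.raise_for_status()
            with open(str(download_to), 'wb') as file_:
                for chunk in response.iter_content(chunk_size=chunk_size):
                    if chunk:  # skip keep-alive chunks
                        file_.write(chunk)

        if extract:
            import patoolib
            patoolib.extract_archive(str(download_to), outdir=str(data_folder))
            download_to.unlink()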


@pass_local_client(clean=False, commit=False)
def dataset_parent(client, revision, datadir, format, ctx=None):
"""Handle datasets subcommands."""
@@ -171,15 +137,24 @@ def add_file(
):
"""Add data file to a dataset."""
add_to_dataset(
client, urls, name, link, force, create, sources, destination, ref,
with_metadata, urlscontext
client=client,
urls=urls,
short_name=name,
link=link,
force=force,
create=create,
sources=sources,
destination=destination,
ref=ref,
with_metadata=with_metadata,
urlscontext=urlscontext
)


def add_to_dataset(
client,
urls,
name,
short_name,
link=False,
force=False,
create=False,
@@ -189,6 +164,9 @@ def add_to_dataset(
with_metadata=None,
urlscontext=contextlib.nullcontext,
commit_message=None,
extract=False,
all_at_once=False,
progress=None,
):
"""Add data to a dataset."""
if len(urls) == 0:
@@ -198,14 +176,9 @@
'Cannot add multiple URLs with --source or --destination'
)

# check for identifier before creating the dataset
identifier = extract_doi(
with_metadata.identifier
) if with_metadata else None

try:
with client.with_dataset(
name=name, identifier=identifier, create=create
short_name=short_name, create=create
) as dataset:
with urlscontext(urls) as bar:
warning_message = client.add_data_to_dataset(
@@ -215,25 +188,20 @@
force=force,
sources=sources,
destination=destination,
ref=ref
ref=ref,
extract=extract,
all_at_once=all_at_once,
progress=progress,
)

if warning_message:
click.echo(WARNING + warning_message)

if with_metadata:
for file_ in with_metadata.files:
for added_ in dataset.files:

if added_.path.endswith(file_.filename):
if isinstance(file_.url, ParseResult):
file_.url = file_.url.geturl()

file_.path = added_.path
file_.url = remove_credentials(file_.url)
file_.creator = with_metadata.creator
file_._label = added_._label
file_.commit = added_.commit
for file_ in dataset.files:
file_.creator = with_metadata.creator
# dataset has the correct list of files
with_metadata.files = dataset.files

dataset.update_metadata(with_metadata)

Expand All @@ -242,7 +210,7 @@ def add_to_dataset(
'Dataset "{0}" does not exist.\n'
'Use "renku dataset create {0}" to create the dataset or retry '
'"renku dataset add {0}" command with "--create" option for '
'automatic dataset creation.'.format(name)
'automatic dataset creation.'.format(short_name)
)
except (FileNotFoundError, git.exc.NoSuchPathError) as e:
raise ParameterError(
@@ -446,10 +414,8 @@ def import_dataset(
short_name='',
extract=False,
with_prompt=False,
pool_init_fn=None,
pool_init_args=None,
download_file_fn=default_download_file,
commit_message=None,
progress=None,
):
"""Import data from a 3rd party provider."""
provider, err = ProviderFactory.from_uri(uri)
@@ -500,53 +466,18 @@
dataset.name, dataset.version
)

dataset.short_name = short_name

client.create_dataset(name=dataset.name, short_name=short_name)

data_folder = tempfile.mkdtemp()

pool_size = min(
int(os.getenv('RENKU_POOL_SIZE',
mp.cpu_count() // 2)), 4
)

freeze_support() # Windows support

pool = mp.Pool(
pool_size,
# Windows support
initializer=pool_init_fn,
initargs=pool_init_args
)

processing = [
pool.apply_async(
download_file_fn, args=(
extract,
data_folder,
file_,
)
) for file_ in files
]

try:
for p in processing:
p.get() # Will internally do the wait() as well.

except HTTPError as e:
raise ParameterError((
'Could not process {0}.\n'
'URI not found.'.format(e.request.url)
))
pool.close()

dataset.url = remove_credentials(dataset.url)

add_to_dataset(
client,
urls=[str(p) for p in Path(data_folder).glob('*')],
name=short_name,
with_metadata=dataset
urls=[f.url for f in files],
short_name=short_name,
create=True,
with_metadata=dataset,
force=True,
extract=extract,
all_at_once=True,
progress=progress,
)

if dataset.version:
(Diffs for the remaining six changed files are not included here.)
