Skip to content

Commit

Permalink
feat(dataset): refactor DatasetTag (#2232)
Browse files Browse the repository at this point in the history
  • Loading branch information
m-alisafaee committed Aug 19, 2021
1 parent 0eb835b commit 00b9afa
Show file tree
Hide file tree
Showing 21 changed files with 440 additions and 398 deletions.
12 changes: 6 additions & 6 deletions renku/cli/dataset.py
Expand Up @@ -406,6 +406,7 @@
from renku.cli.utils.callback import ClickCallback
from renku.core import errors
from renku.core.commands.dataset import (
add_dataset_tag_command,
add_to_dataset,
create_dataset,
edit_dataset,
Expand All @@ -414,11 +415,10 @@
import_dataset,
list_datasets,
list_files,
list_tags,
list_tags_command,
remove_dataset,
remove_dataset_tags,
remove_dataset_tags_command,
show_dataset,
tag_dataset,
update_datasets,
)
from renku.core.commands.format.dataset_files import DATASET_FILES_COLUMNS, DATASET_FILES_FORMATS
Expand Down Expand Up @@ -644,7 +644,7 @@ def remove(name):
@click.option("--force", is_flag=True, help="Allow overwriting existing tags.")
def tag(name, tag, description, force):
"""Create a tag for a dataset."""
tag_dataset().build().execute(name, tag, description, force=force)
add_dataset_tag_command().build().execute(name=name, tag=tag, description=description, force=force)
click.secho("OK", fg="green")


Expand All @@ -653,7 +653,7 @@ def tag(name, tag, description, force):
@click.argument("tags", nargs=-1)
def remove_tags(name, tags):
"""Remove tags from a dataset."""
remove_dataset_tags().build().execute(name, tags)
remove_dataset_tags_command().build().execute(name=name, tags=tags)
click.secho("OK", fg="green")


Expand All @@ -662,7 +662,7 @@ def remove_tags(name, tags):
@click.option("--format", type=click.Choice(DATASET_TAGS_FORMATS), default="tabular", help="Choose an output format.")
def ls_tags(name, format):
"""List all tags of a dataset."""
result = list_tags().lock_dataset().build().execute(name, format)
result = list_tags_command().lock_dataset().build().execute(name=name, format=format)
click.echo(result.output)


Expand Down
177 changes: 68 additions & 109 deletions renku/core/commands/dataset.py
Expand Up @@ -38,7 +38,9 @@
from renku.core.management import LocalClient
from renku.core.management.command_builder import inject
from renku.core.management.command_builder.command import Command
from renku.core.management.dataset import get_dataset
from renku.core.management.dataset.datasets_provenance import DatasetsProvenance
from renku.core.management.dataset.tag import add_dataset_tag, remove_dataset_tags
from renku.core.management.datasets import DATASET_METADATA_PATHS
from renku.core.management.interface.database_gateway import IDatabaseGateway
from renku.core.metadata.immutable import DynamicProxy
Expand All @@ -60,13 +62,19 @@
@inject.autoparams()
def _list_datasets(datasets_provenance: DatasetsProvenance, format=None, columns=None):
"""List all datasets."""
datasets = [DynamicProxy(d) for d in datasets_provenance.datasets]
for dataset in datasets:
tags = datasets_provenance.get_all_tags(dataset)
dataset.tags = tags
dataset.tags_csv = ",".join(tag.name for tag in tags)

if format is None:
return list(datasets_provenance.datasets)
return list(datasets)

if format not in DATASETS_FORMATS:
raise UsageError("format not supported")

return DATASETS_FORMATS[format](datasets_provenance.datasets, columns=columns)
return DATASETS_FORMATS[format](datasets, columns=columns)


def list_datasets():
Expand Down Expand Up @@ -386,18 +394,19 @@ def remove_dataset():


@inject.autoparams()
def _export_dataset(name, provider_name, publish, tag, client: LocalClient, **kwargs):
def _export_dataset(
name, provider_name, publish, tag, client: LocalClient, datasets_provenance: DatasetsProvenance, **kwargs
):
"""Export data to 3rd party provider.
:raises: ``ValueError``, ``HTTPError``, ``InvalidAccessToken``,
``DatasetNotFound``
:raises: ``ParameterError``, ``HTTPError``, ``InvalidAccessToken``, ``DatasetNotFound``
"""
provider_name = provider_name.lower()

# TODO: all these callbacks are ugly, improve in #737
config_key_secret = "access_token"

dataset_ = client.get_dataset(name, strict=True)
dataset = client.get_dataset(name, strict=True, immutable=True)

try:
provider = ProviderFactory.from_id(provider_name)
Expand All @@ -407,60 +416,43 @@ def _export_dataset(name, provider_name, publish, tag, client: LocalClient, **kw
provider.set_parameters(**kwargs)

selected_tag = None
selected_commit = client.repo.head.commit
tags = datasets_provenance.get_all_tags(dataset)

if tag:
selected_tag = next((t for t in dataset_.tags if t.name == tag), None)
selected_tag = next((t for t in tags if t.name == tag), None)

if not selected_tag:
raise ValueError("Tag {} not found".format(tag))

selected_commit = selected_tag.commit
elif dataset_.tags and len(dataset_.tags) > 0:
tag_result = _prompt_tag_selection(dataset_.tags)

if tag_result:
selected_tag = tag_result
selected_commit = tag_result.commit

# FIXME: This won't work and needs to be fixed in #renku-python/issues/2210
# If the tag is created automatically for imported datasets, it
# does not have the dataset yet and we need to use the next commit
with client.with_commit(selected_commit):
test_ds = client.get_dataset(name)
if not test_ds:
commits = client.dataset_commits(dataset_)
next_commit = selected_commit
for commit in commits:
if commit.hexsha == selected_commit:
selected_commit = next_commit.hexsha
break
next_commit = commit

with client.with_commit(selected_commit):
dataset_ = client.get_dataset(name)
if not dataset_:
raise DatasetNotFound(name=name)

dataset_.data_dir = get_dataset_data_dir(client, dataset_)

access_token = client.get_value(provider_name, config_key_secret)
exporter = provider.get_exporter(dataset_, access_token=access_token)

if access_token is None:
access_token = _prompt_access_token(exporter)

if access_token is None or len(access_token) == 0:
raise InvalidAccessToken()

client.set_value(provider_name, config_key_secret, access_token, global_only=True)
exporter.set_access_token(access_token)

try:
destination = exporter.export(publish=publish, tag=selected_tag, client=client)
except errors.AuthenticationError:
client.remove_value(provider_name, config_key_secret, global_only=True)
raise
raise errors.ParameterError(f"Tag '{tag}' not found for dataset '{name}'")
elif tags:
selected_tag = _prompt_tag_selection(tags)

if selected_tag:
dataset = datasets_provenance.get_by_id(selected_tag.dataset_id, immutable=True)

if not dataset:
raise DatasetNotFound(message=f"Cannot find dataset with id: '{selected_tag.dataset_id}'")

data_dir = get_dataset_data_dir(client, dataset)
dataset = DynamicProxy(dataset)
dataset.data_dir = data_dir

access_token = client.get_value(provider_name, config_key_secret)
exporter = provider.get_exporter(dataset, access_token=access_token)

if access_token is None:
access_token = _prompt_access_token(exporter)

if access_token is None or len(access_token) == 0:
raise InvalidAccessToken()

client.set_value(provider_name, config_key_secret, access_token, global_only=True)
exporter.set_access_token(access_token)

try:
destination = exporter.export(publish=publish, tag=selected_tag, client=client)
except errors.AuthenticationError:
client.remove_value(provider_name, config_key_secret, global_only=True)
raise

communication.echo(f"Exported to: {destination}")

Expand Down Expand Up @@ -558,12 +550,7 @@ def _import_dataset(

if dataset.version:
tag_name = re.sub("[^a-zA-Z0-9.-_]", "_", dataset.version)
_tag_dataset_helper(
dataset=dataset,
tag=tag_name,
description=f"Tag {dataset.version} created by renku import",
update_provenance=False,
)
add_dataset_tag(dataset=dataset, tag=tag_name, description=f"Tag {dataset.version} created by renku import")
else:
name = name or dataset.name

Expand Down Expand Up @@ -815,72 +802,47 @@ def _filter(
return sorted(records, key=lambda r: r.date_added)


@inject.autoparams()
def _tag_dataset(name, tag, description, client: LocalClient, update_provenance=True, force=False):
def _add_dataset_tag(name, tag, description, force=False):
"""Creates a new tag for a dataset."""
dataset = client.get_dataset(name, strict=True)
_tag_dataset_helper(
dataset=dataset, tag=tag, description=description, update_provenance=update_provenance, force=force
)
dataset = get_dataset(name, strict=True)
add_dataset_tag(dataset=dataset, tag=tag, description=description, force=force)


@inject.autoparams()
def _tag_dataset_helper(
dataset,
tag,
description,
client: LocalClient,
datasets_provenance: DatasetsProvenance,
update_provenance=True,
force=False,
):
try:
client.add_dataset_tag(dataset, tag, description, force)
except ValueError as e:
raise ParameterError(e)
else:
if update_provenance:
datasets_provenance.add_or_update(dataset)


def tag_dataset():
def add_dataset_tag_command():
"""Command for creating a new tag for a dataset."""
command = Command().command(_tag_dataset).lock_dataset().with_database(write=True)
command = Command().command(_add_dataset_tag).lock_dataset().with_database(write=True)
return command.require_migration().with_commit(commit_only=DATASET_METADATA_PATHS)


@inject.autoparams()
def _remove_dataset_tags(name, tags, client: LocalClient, datasets_provenance: DatasetsProvenance):
def _remove_dataset_tags(name, tags):
"""Removes tags from a dataset."""
dataset = client.get_dataset(name, strict=True)

try:
client.remove_dataset_tags(dataset, tags)
except ValueError as e:
raise ParameterError(e)
else:
datasets_provenance.add_or_update(dataset)
dataset = get_dataset(name, strict=True)
remove_dataset_tags(dataset, tags)


def remove_dataset_tags():
def remove_dataset_tags_command():
"""Command for removing tags from a dataset."""
command = Command().command(_remove_dataset_tags).lock_dataset().with_database(write=True)
return command.require_migration().with_commit(commit_only=DATASET_METADATA_PATHS)


@inject.autoparams()
def _list_tags(name, format, client: LocalClient):
def _list_dataset_tags(name, format, datasets_provenance: DatasetsProvenance):
"""List all tags for a dataset."""
dataset = client.get_dataset(name, strict=True)
dataset = get_dataset(name, strict=True)

tags = sorted(dataset.tags, key=lambda t: t.date_created)
tags = datasets_provenance.get_all_tags(dataset)
tags = sorted(tags, key=lambda t: t.date_created)
tags = [DynamicProxy(t) for t in tags]
for tag in tags:
tag.dataset = dataset.title

return DATASET_TAGS_FORMATS[format](tags)


def list_tags():
def list_tags_command():
"""Command for listing a dataset's tags."""
return Command().command(_list_tags).with_database().require_migration()
return Command().command(_list_dataset_tags).with_database().require_migration()


def _prompt_access_token(exporter):
Expand All @@ -897,13 +859,10 @@ def _prompt_access_token(exporter):

def _prompt_tag_selection(tags) -> Optional[DatasetTag]:
"""Prompt user to chose a tag or <HEAD>."""
# Prompt user to select a tag to export
tags = sorted(tags, key=lambda t: t.date_created)

text_prompt = "Tag to export: \n\n<HEAD>\t[1]\n"

text_prompt += "\n".join("{}\t[{}]".format(t.name, i) for i, t in enumerate(tags, start=2))

text_prompt += "\n".join(f"{t.name}\t[{i}]" for i, t in enumerate(tags, start=2))
text_prompt += "\n\nTag"
selection = communication.prompt(text_prompt, type=click.IntRange(1, len(tags) + 1), default=1)

Expand Down
8 changes: 7 additions & 1 deletion renku/core/commands/format/dataset_tags.py
Expand Up @@ -30,7 +30,13 @@ def tabular(tags):
return tabulate(
tags,
headers=OrderedDict(
(("date_created", "created"), ("name", None), ("description", None), ("dataset", None), ("commit", None))
(
("date_created", "created"),
("name", None),
("description", None),
("dataset", None),
("dataset_id", "dataset id"),
)
),
# workaround for tabulate issue 181
# https://bitbucket.org/astanin/python-tabulate/issues/181/disable_numparse-fails-on-empty-input
Expand Down
4 changes: 3 additions & 1 deletion renku/core/commands/providers/dataverse.py
Expand Up @@ -46,6 +46,7 @@
from renku.core.models.provenance.agent import PersonSchema
from renku.core.utils.doi import extract_doi, is_doi
from renku.core.utils.file_size import bytes_to_unit
from renku.core.utils.git import get_content
from renku.core.utils.requests import retry

DATAVERSE_API_PATH = "api/v1"
Expand Down Expand Up @@ -424,7 +425,8 @@ def export(self, publish, client=None, **kwargs):
path = (client.path / file.entity.path).relative_to(self.dataset.data_dir)
except ValueError:
path = Path(file.entity.path)
deposition.upload_file(full_path=client.path / file.entity.path, path_in_dataset=path)
filepath = get_content(repo=client.repo, path=file.entity.path, checksum=file.entity.checksum)
deposition.upload_file(full_path=filepath, path_in_dataset=path)
progressbar.update(1)

if publish:
Expand Down
4 changes: 3 additions & 1 deletion renku/core/commands/providers/olos.py
Expand Up @@ -31,6 +31,7 @@
from renku.core.management import LocalClient
from renku.core.management.command_builder import inject
from renku.core.utils import communication
from renku.core.utils.git import get_content
from renku.core.utils.requests import retry


Expand Down Expand Up @@ -116,7 +117,8 @@ def export(self, publish, client=None, **kwargs):
path = (client.path / file.entity.path).relative_to(self.dataset.data_dir)
except ValueError:
path = Path(file.entity.path)
deposition.upload_file(full_path=client.path / file.entity.path, path_in_dataset=path)
filepath = get_content(repo=client.repo, path=file.entity.path, checksum=file.entity.checksum)
deposition.upload_file(full_path=filepath, path_in_dataset=path)
communication.update_progress(progress_text, amount=1)
finally:
communication.finalize_progress(progress_text)
Expand Down

0 comments on commit 00b9afa

Please sign in to comment.