Skip to content

Commit

Permalink
re-re-factor progress bar!
Browse files Browse the repository at this point in the history
  • Loading branch information
chrisjsewell committed Oct 24, 2020
1 parent 4d3d552 commit a54c319
Show file tree
Hide file tree
Showing 5 changed files with 164 additions and 94 deletions.
3 changes: 2 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ repos:
(?x)^(
aiida/engine/processes/calcjobs/calcjob.py|
aiida/tools/groups/paths.py|
aiida/tools/importexport/dbexport/__init__.py
aiida/tools/importexport/dbexport/__init__.py|
aiida/tools/importexport/common/progress_reporter.py|
)$
- repo: local
Expand Down
9 changes: 7 additions & 2 deletions aiida/cmdline/commands/cmd_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
###########################################################################
# pylint: disable=too-many-arguments,import-error,too-many-locals
"""`verdi export` command."""

from functools import partial
import os
import tempfile

Expand Down Expand Up @@ -95,8 +95,11 @@ def create(
their provenance, according to the rules outlined in the documentation.
You can modify some of those rules using options of this command.
"""
from tqdm import tqdm
from aiida.common.progress_reporter import set_progress_reporter
from aiida.tools.importexport import export, ExportFileFormat
from aiida.tools.importexport.common.exceptions import ArchiveExportError
from aiida.tools.importexport.common.config import BAR_FORMAT

entities = []

Expand Down Expand Up @@ -133,8 +136,10 @@ def create(
elif archive_format == 'tar.gz':
export_format = ExportFileFormat.TAR_GZIPPED

set_progress_reporter(partial(tqdm, bar_format=BAR_FORMAT, leave=verbose))

try:
export(entities, filename=output_file, file_format=export_format, verbose=verbose, **kwargs)
export(entities, filename=output_file, file_format=export_format, **kwargs)
except ArchiveExportError as exception:
echo.echo_critical(f'failed to write the archive file. Exception: {exception}')
else:
Expand Down
6 changes: 5 additions & 1 deletion aiida/common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,9 @@
from .extendeddicts import *
from .links import *
from .log import *
from .progress_reporter import *

__all__ = (datastructures.__all__ + exceptions.__all__ + extendeddicts.__all__ + links.__all__ + log.__all__)
__all__ = (
datastructures.__all__ + exceptions.__all__ + extendeddicts.__all__ + links.__all__ + log.__all__ +
progress_reporter.__all__
)
103 changes: 103 additions & 0 deletions aiida/common/progress_reporter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# -*- coding: utf-8 -*-
###########################################################################
# Copyright (c), The AiiDA team. All rights reserved. #
# This file is part of the AiiDA code. #
# #
# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core #
# For further information on the license, see the LICENSE.txt file #
# For further information please visit http://www.aiida.net #
###########################################################################
# pylint: disable=global-statement,unused-argument
"""Provide a singleton progress reporter implementation.
The interface is inspired by `tqdm <https://github.com/tqdm/tqdm>`,
and indeed a valid implementation is::
from tqdm import tqdm
set_progress_reporter(tqdm)
"""
from contextlib import contextmanager
from typing import Any, Callable, ContextManager, Iterator, Optional

__all__ = ('get_progress_reporter', 'set_progress_reporter', 'progress_reporter_base', 'ProgressIncrementerBase')


class ProgressIncrementerBase:
"""A base class for incrementing a progress reporter."""

def set_description_str(self, text: Optional[str] = None, refresh: bool = True):
"""Set the text shown by the progress reporter.
:param text: The text to show
:param refresh: Force refresh of the progress reporter
"""

def update(self, n: int = 1): # pylint: disable=invalid-name
"""Update the progress counter.
:param n: Increment to add to the internal counter of iterations
"""


@contextmanager
def progress_reporter_base(*,
total: int,
desc: Optional[str] = None,
**kwargs: Any) -> Iterator[ProgressIncrementerBase]:
"""A context manager for providing a progress reporter for a process.
Example Usage::
with progress_reporter(total=10, desc="A process:") as progress:
for i in range(10):
progress.set_description_str(f"A process: {i}")
progress.update()
:param total: The number of expected iterations.
:param desc: A description of the process
:yield: A class for incrementing the progress reporter
"""
yield ProgressIncrementerBase()


PROGRESS_REPORTER = progress_reporter_base


def get_progress_reporter() -> Callable[..., ContextManager[Any]]:
"""Return the progress reporter
Example Usage::
with get_progress_reporter()(total=10, desc="A process:") as progress:
for i in range(10):
progress.set_description_str(f"A process: {i}")
progress.update()
"""
global PROGRESS_REPORTER
return PROGRESS_REPORTER # type: ignore


def set_progress_reporter(reporter: Optional[Callable[..., ContextManager[Any]]] = None):
"""Set the progress reporter implementation
:param reporter: A context manager for providing a progress reporter for a process.
If None, reset to default null reporter
The reporter should be a context manager that implements the
:func:`~aiida.common.progress_reporter.progress_reporter_base` interface.
Example Usage::
with get_progress_reporter()(total=10, desc="A process:") as progress:
for i in range(10):
progress.set_description_str(f"A process: {i}")
progress.update()
"""
global PROGRESS_REPORTER
PROGRESS_REPORTER = reporter or progress_reporter_base # type: ignore
Loading

0 comments on commit a54c319

Please sign in to comment.